sqlite.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579
  1. #! /usr/bin/env python3
  2. #
  3. # Copyright (C) 2023 Garmin Ltd.
  4. #
  5. # SPDX-License-Identifier: GPL-2.0-only
  6. #
  7. from datetime import datetime, timezone
  8. import sqlite3
  9. import logging
  10. from contextlib import closing
  11. from . import User
# Module-level logger, registered under the "hashserv.sqlite" name.
logger = logging.getLogger("hashserv.sqlite")
  13. UNIHASH_TABLE_DEFINITION = (
  14. ("method", "TEXT NOT NULL", "UNIQUE"),
  15. ("taskhash", "TEXT NOT NULL", "UNIQUE"),
  16. ("unihash", "TEXT NOT NULL", ""),
  17. ("gc_mark", "TEXT NOT NULL", ""),
  18. )
  19. UNIHASH_TABLE_COLUMNS = tuple(name for name, _, _ in UNIHASH_TABLE_DEFINITION)
  20. OUTHASH_TABLE_DEFINITION = (
  21. ("method", "TEXT NOT NULL", "UNIQUE"),
  22. ("taskhash", "TEXT NOT NULL", "UNIQUE"),
  23. ("outhash", "TEXT NOT NULL", "UNIQUE"),
  24. ("created", "DATETIME", ""),
  25. # Optional fields
  26. ("owner", "TEXT", ""),
  27. ("PN", "TEXT", ""),
  28. ("PV", "TEXT", ""),
  29. ("PR", "TEXT", ""),
  30. ("task", "TEXT", ""),
  31. ("outhash_siginfo", "TEXT", ""),
  32. )
  33. OUTHASH_TABLE_COLUMNS = tuple(name for name, _, _ in OUTHASH_TABLE_DEFINITION)
  34. USERS_TABLE_DEFINITION = (
  35. ("username", "TEXT NOT NULL", "UNIQUE"),
  36. ("token", "TEXT NOT NULL", ""),
  37. ("permissions", "TEXT NOT NULL", ""),
  38. )
  39. USERS_TABLE_COLUMNS = tuple(name for name, _, _ in USERS_TABLE_DEFINITION)
  40. CONFIG_TABLE_DEFINITION = (
  41. ("name", "TEXT NOT NULL", "UNIQUE"),
  42. ("value", "TEXT", ""),
  43. )
  44. CONFIG_TABLE_COLUMNS = tuple(name for name, _, _ in CONFIG_TABLE_DEFINITION)
  45. def adapt_datetime_iso(val):
  46. """Adapt datetime.datetime to UTC ISO 8601 date."""
  47. return val.astimezone(timezone.utc).isoformat()
  48. sqlite3.register_adapter(datetime, adapt_datetime_iso)
  49. def convert_datetime(val):
  50. """Convert ISO 8601 datetime to datetime.datetime object."""
  51. return datetime.fromisoformat(val.decode())
  52. sqlite3.register_converter("DATETIME", convert_datetime)
  53. def _make_table(cursor, name, definition):
  54. cursor.execute(
  55. """
  56. CREATE TABLE IF NOT EXISTS {name} (
  57. id INTEGER PRIMARY KEY AUTOINCREMENT,
  58. {fields}
  59. UNIQUE({unique})
  60. )
  61. """.format(
  62. name=name,
  63. fields=" ".join("%s %s," % (name, typ) for name, typ, _ in definition),
  64. unique=", ".join(
  65. name for name, _, flags in definition if "UNIQUE" in flags
  66. ),
  67. )
  68. )
  69. def map_user(row):
  70. if row is None:
  71. return None
  72. return User(
  73. username=row["username"],
  74. permissions=set(row["permissions"].split()),
  75. )
  76. def _make_condition_statement(columns, condition):
  77. where = {}
  78. for c in columns:
  79. if c in condition and condition[c] is not None:
  80. where[c] = condition[c]
  81. return where, " AND ".join("%s=:%s" % (k, k) for k in where.keys())
  82. def _get_sqlite_version(cursor):
  83. cursor.execute("SELECT sqlite_version()")
  84. version = []
  85. for v in cursor.fetchone()[0].split("."):
  86. try:
  87. version.append(int(v))
  88. except ValueError:
  89. version.append(v)
  90. return tuple(version)
  91. def _schema_table_name(version):
  92. if version >= (3, 33):
  93. return "sqlite_schema"
  94. return "sqlite_master"
  95. class DatabaseEngine(object):
  96. def __init__(self, dbname, sync):
  97. self.dbname = dbname
  98. self.logger = logger
  99. self.sync = sync
  100. async def create(self):
  101. db = sqlite3.connect(self.dbname)
  102. db.row_factory = sqlite3.Row
  103. with closing(db.cursor()) as cursor:
  104. _make_table(cursor, "unihashes_v3", UNIHASH_TABLE_DEFINITION)
  105. _make_table(cursor, "outhashes_v2", OUTHASH_TABLE_DEFINITION)
  106. _make_table(cursor, "users", USERS_TABLE_DEFINITION)
  107. _make_table(cursor, "config", CONFIG_TABLE_DEFINITION)
  108. cursor.execute("PRAGMA journal_mode = WAL")
  109. cursor.execute(
  110. "PRAGMA synchronous = %s" % ("NORMAL" if self.sync else "OFF")
  111. )
  112. # Drop old indexes
  113. cursor.execute("DROP INDEX IF EXISTS taskhash_lookup")
  114. cursor.execute("DROP INDEX IF EXISTS outhash_lookup")
  115. cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v2")
  116. cursor.execute("DROP INDEX IF EXISTS outhash_lookup_v2")
  117. cursor.execute("DROP INDEX IF EXISTS taskhash_lookup_v3")
  118. # TODO: Upgrade from tasks_v2?
  119. cursor.execute("DROP TABLE IF EXISTS tasks_v2")
  120. # Create new indexes
  121. cursor.execute(
  122. "CREATE INDEX IF NOT EXISTS taskhash_lookup_v4 ON unihashes_v3 (method, taskhash)"
  123. )
  124. cursor.execute(
  125. "CREATE INDEX IF NOT EXISTS unihash_lookup_v1 ON unihashes_v3 (unihash)"
  126. )
  127. cursor.execute(
  128. "CREATE INDEX IF NOT EXISTS outhash_lookup_v3 ON outhashes_v2 (method, outhash)"
  129. )
  130. cursor.execute("CREATE INDEX IF NOT EXISTS config_lookup ON config (name)")
  131. sqlite_version = _get_sqlite_version(cursor)
  132. cursor.execute(
  133. f"""
  134. SELECT name FROM {_schema_table_name(sqlite_version)} WHERE type = 'table' AND name = 'unihashes_v2'
  135. """
  136. )
  137. if cursor.fetchone():
  138. self.logger.info("Upgrading Unihashes V2 -> V3...")
  139. cursor.execute(
  140. """
  141. INSERT INTO unihashes_v3 (id, method, unihash, taskhash, gc_mark)
  142. SELECT id, method, unihash, taskhash, '' FROM unihashes_v2
  143. """
  144. )
  145. cursor.execute("DROP TABLE unihashes_v2")
  146. db.commit()
  147. self.logger.info("Upgrade complete")
  148. def connect(self, logger):
  149. return Database(logger, self.dbname, self.sync)
  150. class Database(object):
  151. def __init__(self, logger, dbname, sync):
  152. self.dbname = dbname
  153. self.logger = logger
  154. self.db = sqlite3.connect(self.dbname)
  155. self.db.row_factory = sqlite3.Row
  156. with closing(self.db.cursor()) as cursor:
  157. cursor.execute("PRAGMA journal_mode = WAL")
  158. cursor.execute(
  159. "PRAGMA synchronous = %s" % ("NORMAL" if sync else "OFF")
  160. )
  161. self.sqlite_version = _get_sqlite_version(cursor)
  162. async def __aenter__(self):
  163. return self
  164. async def __aexit__(self, exc_type, exc_value, traceback):
  165. await self.close()
  166. async def _set_config(self, cursor, name, value):
  167. cursor.execute(
  168. """
  169. INSERT OR REPLACE INTO config (id, name, value) VALUES
  170. ((SELECT id FROM config WHERE name=:name), :name, :value)
  171. """,
  172. {
  173. "name": name,
  174. "value": value,
  175. },
  176. )
  177. async def _get_config(self, cursor, name):
  178. cursor.execute(
  179. "SELECT value FROM config WHERE name=:name",
  180. {
  181. "name": name,
  182. },
  183. )
  184. row = cursor.fetchone()
  185. if row is None:
  186. return None
  187. return row["value"]
  188. async def close(self):
  189. self.db.close()
  190. async def get_unihash_by_taskhash_full(self, method, taskhash):
  191. with closing(self.db.cursor()) as cursor:
  192. cursor.execute(
  193. """
  194. SELECT *, unihashes_v3.unihash AS unihash FROM outhashes_v2
  195. INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash
  196. WHERE outhashes_v2.method=:method AND outhashes_v2.taskhash=:taskhash
  197. ORDER BY outhashes_v2.created ASC
  198. LIMIT 1
  199. """,
  200. {
  201. "method": method,
  202. "taskhash": taskhash,
  203. },
  204. )
  205. return cursor.fetchone()
  206. async def get_unihash_by_outhash(self, method, outhash):
  207. with closing(self.db.cursor()) as cursor:
  208. cursor.execute(
  209. """
  210. SELECT *, unihashes_v3.unihash AS unihash FROM outhashes_v2
  211. INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash
  212. WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
  213. ORDER BY outhashes_v2.created ASC
  214. LIMIT 1
  215. """,
  216. {
  217. "method": method,
  218. "outhash": outhash,
  219. },
  220. )
  221. return cursor.fetchone()
  222. async def unihash_exists(self, unihash):
  223. with closing(self.db.cursor()) as cursor:
  224. cursor.execute(
  225. """
  226. SELECT * FROM unihashes_v3 WHERE unihash=:unihash
  227. LIMIT 1
  228. """,
  229. {
  230. "unihash": unihash,
  231. },
  232. )
  233. return cursor.fetchone() is not None
  234. async def get_outhash(self, method, outhash):
  235. with closing(self.db.cursor()) as cursor:
  236. cursor.execute(
  237. """
  238. SELECT * FROM outhashes_v2
  239. WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash
  240. ORDER BY outhashes_v2.created ASC
  241. LIMIT 1
  242. """,
  243. {
  244. "method": method,
  245. "outhash": outhash,
  246. },
  247. )
  248. return cursor.fetchone()
  249. async def get_equivalent_for_outhash(self, method, outhash, taskhash):
  250. with closing(self.db.cursor()) as cursor:
  251. cursor.execute(
  252. """
  253. SELECT outhashes_v2.taskhash AS taskhash, unihashes_v3.unihash AS unihash FROM outhashes_v2
  254. INNER JOIN unihashes_v3 ON unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash
  255. -- Select any matching output hash except the one we just inserted
  256. WHERE outhashes_v2.method=:method AND outhashes_v2.outhash=:outhash AND outhashes_v2.taskhash!=:taskhash
  257. -- Pick the oldest hash
  258. ORDER BY outhashes_v2.created ASC
  259. LIMIT 1
  260. """,
  261. {
  262. "method": method,
  263. "outhash": outhash,
  264. "taskhash": taskhash,
  265. },
  266. )
  267. return cursor.fetchone()
  268. async def get_equivalent(self, method, taskhash):
  269. with closing(self.db.cursor()) as cursor:
  270. cursor.execute(
  271. "SELECT taskhash, method, unihash FROM unihashes_v3 WHERE method=:method AND taskhash=:taskhash",
  272. {
  273. "method": method,
  274. "taskhash": taskhash,
  275. },
  276. )
  277. return cursor.fetchone()
  278. async def remove(self, condition):
  279. def do_remove(columns, table_name, cursor):
  280. where, clause = _make_condition_statement(columns, condition)
  281. if where:
  282. query = f"DELETE FROM {table_name} WHERE {clause}"
  283. cursor.execute(query, where)
  284. return cursor.rowcount
  285. return 0
  286. count = 0
  287. with closing(self.db.cursor()) as cursor:
  288. count += do_remove(OUTHASH_TABLE_COLUMNS, "outhashes_v2", cursor)
  289. count += do_remove(UNIHASH_TABLE_COLUMNS, "unihashes_v3", cursor)
  290. self.db.commit()
  291. return count
  292. async def get_current_gc_mark(self):
  293. with closing(self.db.cursor()) as cursor:
  294. return await self._get_config(cursor, "gc-mark")
  295. async def gc_status(self):
  296. with closing(self.db.cursor()) as cursor:
  297. cursor.execute(
  298. """
  299. SELECT COUNT() FROM unihashes_v3 WHERE
  300. gc_mark=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
  301. """
  302. )
  303. keep_rows = cursor.fetchone()[0]
  304. cursor.execute(
  305. """
  306. SELECT COUNT() FROM unihashes_v3 WHERE
  307. gc_mark!=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
  308. """
  309. )
  310. remove_rows = cursor.fetchone()[0]
  311. current_mark = await self._get_config(cursor, "gc-mark")
  312. return (keep_rows, remove_rows, current_mark)
  313. async def gc_mark(self, mark, condition):
  314. with closing(self.db.cursor()) as cursor:
  315. await self._set_config(cursor, "gc-mark", mark)
  316. where, clause = _make_condition_statement(UNIHASH_TABLE_COLUMNS, condition)
  317. new_rows = 0
  318. if where:
  319. cursor.execute(
  320. f"""
  321. UPDATE unihashes_v3 SET
  322. gc_mark=COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
  323. WHERE {clause}
  324. """,
  325. where,
  326. )
  327. new_rows = cursor.rowcount
  328. self.db.commit()
  329. return new_rows
  330. async def gc_sweep(self):
  331. with closing(self.db.cursor()) as cursor:
  332. # NOTE: COALESCE is not used in this query so that if the current
  333. # mark is NULL, nothing will happen
  334. cursor.execute(
  335. """
  336. DELETE FROM unihashes_v3 WHERE
  337. gc_mark!=(SELECT value FROM config WHERE name='gc-mark')
  338. """
  339. )
  340. count = cursor.rowcount
  341. await self._set_config(cursor, "gc-mark", None)
  342. self.db.commit()
  343. return count
  344. async def clean_unused(self, oldest):
  345. with closing(self.db.cursor()) as cursor:
  346. cursor.execute(
  347. """
  348. DELETE FROM outhashes_v2 WHERE created<:oldest AND NOT EXISTS (
  349. SELECT unihashes_v3.id FROM unihashes_v3 WHERE unihashes_v3.method=outhashes_v2.method AND unihashes_v3.taskhash=outhashes_v2.taskhash LIMIT 1
  350. )
  351. """,
  352. {
  353. "oldest": oldest,
  354. },
  355. )
  356. self.db.commit()
  357. return cursor.rowcount
  358. async def insert_unihash(self, method, taskhash, unihash):
  359. with closing(self.db.cursor()) as cursor:
  360. prevrowid = cursor.lastrowid
  361. cursor.execute(
  362. """
  363. INSERT OR IGNORE INTO unihashes_v3 (method, taskhash, unihash, gc_mark) VALUES
  364. (
  365. :method,
  366. :taskhash,
  367. :unihash,
  368. COALESCE((SELECT value FROM config WHERE name='gc-mark'), '')
  369. )
  370. """,
  371. {
  372. "method": method,
  373. "taskhash": taskhash,
  374. "unihash": unihash,
  375. },
  376. )
  377. self.db.commit()
  378. return cursor.lastrowid != prevrowid
  379. async def insert_outhash(self, data):
  380. data = {k: v for k, v in data.items() if k in OUTHASH_TABLE_COLUMNS}
  381. keys = sorted(data.keys())
  382. query = "INSERT OR IGNORE INTO outhashes_v2 ({fields}) VALUES({values})".format(
  383. fields=", ".join(keys),
  384. values=", ".join(":" + k for k in keys),
  385. )
  386. with closing(self.db.cursor()) as cursor:
  387. prevrowid = cursor.lastrowid
  388. cursor.execute(query, data)
  389. self.db.commit()
  390. return cursor.lastrowid != prevrowid
  391. def _get_user(self, username):
  392. with closing(self.db.cursor()) as cursor:
  393. cursor.execute(
  394. """
  395. SELECT username, permissions, token FROM users WHERE username=:username
  396. """,
  397. {
  398. "username": username,
  399. },
  400. )
  401. return cursor.fetchone()
  402. async def lookup_user_token(self, username):
  403. row = self._get_user(username)
  404. if row is None:
  405. return None, None
  406. return map_user(row), row["token"]
  407. async def lookup_user(self, username):
  408. return map_user(self._get_user(username))
  409. async def set_user_token(self, username, token):
  410. with closing(self.db.cursor()) as cursor:
  411. cursor.execute(
  412. """
  413. UPDATE users SET token=:token WHERE username=:username
  414. """,
  415. {
  416. "username": username,
  417. "token": token,
  418. },
  419. )
  420. self.db.commit()
  421. return cursor.rowcount != 0
  422. async def set_user_perms(self, username, permissions):
  423. with closing(self.db.cursor()) as cursor:
  424. cursor.execute(
  425. """
  426. UPDATE users SET permissions=:permissions WHERE username=:username
  427. """,
  428. {
  429. "username": username,
  430. "permissions": " ".join(permissions),
  431. },
  432. )
  433. self.db.commit()
  434. return cursor.rowcount != 0
  435. async def get_all_users(self):
  436. with closing(self.db.cursor()) as cursor:
  437. cursor.execute("SELECT username, permissions FROM users")
  438. return [map_user(r) for r in cursor.fetchall()]
  439. async def new_user(self, username, permissions, token):
  440. with closing(self.db.cursor()) as cursor:
  441. try:
  442. cursor.execute(
  443. """
  444. INSERT INTO users (username, token, permissions) VALUES (:username, :token, :permissions)
  445. """,
  446. {
  447. "username": username,
  448. "token": token,
  449. "permissions": " ".join(permissions),
  450. },
  451. )
  452. self.db.commit()
  453. return True
  454. except sqlite3.IntegrityError:
  455. return False
  456. async def delete_user(self, username):
  457. with closing(self.db.cursor()) as cursor:
  458. cursor.execute(
  459. """
  460. DELETE FROM users WHERE username=:username
  461. """,
  462. {
  463. "username": username,
  464. },
  465. )
  466. self.db.commit()
  467. return cursor.rowcount != 0
  468. async def get_usage(self):
  469. usage = {}
  470. with closing(self.db.cursor()) as cursor:
  471. cursor.execute(
  472. f"""
  473. SELECT name FROM {_schema_table_name(self.sqlite_version)} WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
  474. """
  475. )
  476. for row in cursor.fetchall():
  477. cursor.execute(
  478. """
  479. SELECT COUNT() FROM %s
  480. """
  481. % row["name"],
  482. )
  483. usage[row["name"]] = {
  484. "rows": cursor.fetchone()[0],
  485. }
  486. return usage
  487. async def get_query_columns(self):
  488. columns = set()
  489. for name, typ, _ in UNIHASH_TABLE_DEFINITION + OUTHASH_TABLE_DEFINITION:
  490. if typ.startswith("TEXT"):
  491. columns.add(name)
  492. return list(columns)