sqlalchemy.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. #! /usr/bin/env python3
  2. #
  3. # Copyright (C) 2023 Garmin Ltd.
  4. #
  5. # SPDX-License-Identifier: GPL-2.0-only
  6. #
  7. import logging
  8. from datetime import datetime
  9. from . import User
  10. from sqlalchemy.ext.asyncio import create_async_engine
  11. from sqlalchemy.pool import NullPool
  12. from sqlalchemy import (
  13. MetaData,
  14. Column,
  15. Table,
  16. Text,
  17. Integer,
  18. UniqueConstraint,
  19. DateTime,
  20. Index,
  21. select,
  22. insert,
  23. exists,
  24. literal,
  25. and_,
  26. delete,
  27. update,
  28. func,
  29. inspect,
  30. )
  31. import sqlalchemy.engine
  32. from sqlalchemy.orm import declarative_base
  33. from sqlalchemy.exc import IntegrityError
  34. from sqlalchemy.dialects.postgresql import insert as postgres_insert
# Declarative base for the current schema. DatabaseEngine.create() runs
# create_all() against this base's metadata.
Base = declarative_base()
  36. class UnihashesV3(Base):
  37. __tablename__ = "unihashes_v3"
  38. id = Column(Integer, primary_key=True, autoincrement=True)
  39. method = Column(Text, nullable=False)
  40. taskhash = Column(Text, nullable=False)
  41. unihash = Column(Text, nullable=False)
  42. gc_mark = Column(Text, nullable=False)
  43. __table_args__ = (
  44. UniqueConstraint("method", "taskhash"),
  45. Index("taskhash_lookup_v4", "method", "taskhash"),
  46. Index("unihash_lookup_v1", "unihash"),
  47. )
  48. class OuthashesV2(Base):
  49. __tablename__ = "outhashes_v2"
  50. id = Column(Integer, primary_key=True, autoincrement=True)
  51. method = Column(Text, nullable=False)
  52. taskhash = Column(Text, nullable=False)
  53. outhash = Column(Text, nullable=False)
  54. created = Column(DateTime)
  55. owner = Column(Text)
  56. PN = Column(Text)
  57. PV = Column(Text)
  58. PR = Column(Text)
  59. task = Column(Text)
  60. outhash_siginfo = Column(Text)
  61. __table_args__ = (
  62. UniqueConstraint("method", "taskhash", "outhash"),
  63. Index("outhash_lookup_v3", "method", "outhash"),
  64. )
  65. class Users(Base):
  66. __tablename__ = "users"
  67. id = Column(Integer, primary_key=True, autoincrement=True)
  68. username = Column(Text, nullable=False)
  69. token = Column(Text, nullable=False)
  70. permissions = Column(Text)
  71. __table_args__ = (UniqueConstraint("username"),)
  72. class Config(Base):
  73. __tablename__ = "config"
  74. id = Column(Integer, primary_key=True, autoincrement=True)
  75. name = Column(Text, nullable=False)
  76. value = Column(Text)
  77. __table_args__ = (
  78. UniqueConstraint("name"),
  79. Index("config_lookup", "name"),
  80. )
#
# Old table versions
#
# These are declared on their own declarative base, so they are not part of
# Base.metadata and are therefore not created by Base.metadata.create_all().
DeprecatedBase = declarative_base()
  85. class UnihashesV2(DeprecatedBase):
  86. __tablename__ = "unihashes_v2"
  87. id = Column(Integer, primary_key=True, autoincrement=True)
  88. method = Column(Text, nullable=False)
  89. taskhash = Column(Text, nullable=False)
  90. unihash = Column(Text, nullable=False)
  91. __table_args__ = (
  92. UniqueConstraint("method", "taskhash"),
  93. Index("taskhash_lookup_v3", "method", "taskhash"),
  94. )
  95. class DatabaseEngine(object):
  96. def __init__(self, url, username=None, password=None):
  97. self.logger = logging.getLogger("hashserv.sqlalchemy")
  98. self.url = sqlalchemy.engine.make_url(url)
  99. if username is not None:
  100. self.url = self.url.set(username=username)
  101. if password is not None:
  102. self.url = self.url.set(password=password)
  103. async def create(self):
  104. def check_table_exists(conn, name):
  105. return inspect(conn).has_table(name)
  106. self.logger.info("Using database %s", self.url)
  107. if self.url.drivername == 'postgresql+psycopg':
  108. # Psygopg 3 (psygopg) driver can handle async connection pooling
  109. self.engine = create_async_engine(self.url, max_overflow=-1)
  110. else:
  111. self.engine = create_async_engine(self.url, poolclass=NullPool)
  112. async with self.engine.begin() as conn:
  113. # Create tables
  114. self.logger.info("Creating tables...")
  115. await conn.run_sync(Base.metadata.create_all)
  116. if await conn.run_sync(check_table_exists, UnihashesV2.__tablename__):
  117. self.logger.info("Upgrading Unihashes V2 -> V3...")
  118. statement = insert(UnihashesV3).from_select(
  119. ["id", "method", "unihash", "taskhash", "gc_mark"],
  120. select(
  121. UnihashesV2.id,
  122. UnihashesV2.method,
  123. UnihashesV2.unihash,
  124. UnihashesV2.taskhash,
  125. literal("").label("gc_mark"),
  126. ),
  127. )
  128. self.logger.debug("%s", statement)
  129. await conn.execute(statement)
  130. await conn.run_sync(Base.metadata.drop_all, [UnihashesV2.__table__])
  131. self.logger.info("Upgrade complete")
  132. def connect(self, logger):
  133. return Database(self.engine, logger)
  134. def map_row(row):
  135. if row is None:
  136. return None
  137. return dict(**row._mapping)
  138. def map_user(row):
  139. if row is None:
  140. return None
  141. return User(
  142. username=row.username,
  143. permissions=set(row.permissions.split()),
  144. )
  145. def _make_condition_statement(table, condition):
  146. where = {}
  147. for c in table.__table__.columns:
  148. if c.key in condition and condition[c.key] is not None:
  149. where[c] = condition[c.key]
  150. return [(k == v) for k, v in where.items()]
  151. class Database(object):
  152. def __init__(self, engine, logger):
  153. self.engine = engine
  154. self.db = None
  155. self.logger = logger
  156. async def __aenter__(self):
  157. self.db = await self.engine.connect()
  158. return self
  159. async def __aexit__(self, exc_type, exc_value, traceback):
  160. await self.close()
  161. async def close(self):
  162. await self.db.close()
  163. self.db = None
  164. async def _execute(self, statement):
  165. self.logger.debug("%s", statement)
  166. return await self.db.execute(statement)
  167. async def _set_config(self, name, value):
  168. while True:
  169. result = await self._execute(
  170. update(Config).where(Config.name == name).values(value=value)
  171. )
  172. if result.rowcount == 0:
  173. self.logger.debug("Config '%s' not found. Adding it", name)
  174. try:
  175. await self._execute(insert(Config).values(name=name, value=value))
  176. except IntegrityError:
  177. # Race. Try again
  178. continue
  179. break
  180. def _get_config_subquery(self, name, default=None):
  181. if default is not None:
  182. return func.coalesce(
  183. select(Config.value).where(Config.name == name).scalar_subquery(),
  184. default,
  185. )
  186. return select(Config.value).where(Config.name == name).scalar_subquery()
  187. async def _get_config(self, name):
  188. result = await self._execute(select(Config.value).where(Config.name == name))
  189. row = result.first()
  190. if row is None:
  191. return None
  192. return row.value
  193. async def get_unihash_by_taskhash_full(self, method, taskhash):
  194. async with self.db.begin():
  195. result = await self._execute(
  196. select(
  197. OuthashesV2,
  198. UnihashesV3.unihash.label("unihash"),
  199. )
  200. .join(
  201. UnihashesV3,
  202. and_(
  203. UnihashesV3.method == OuthashesV2.method,
  204. UnihashesV3.taskhash == OuthashesV2.taskhash,
  205. ),
  206. )
  207. .where(
  208. OuthashesV2.method == method,
  209. OuthashesV2.taskhash == taskhash,
  210. )
  211. .order_by(
  212. OuthashesV2.created.asc(),
  213. )
  214. .limit(1)
  215. )
  216. return map_row(result.first())
  217. async def get_unihash_by_outhash(self, method, outhash):
  218. async with self.db.begin():
  219. result = await self._execute(
  220. select(OuthashesV2, UnihashesV3.unihash.label("unihash"))
  221. .join(
  222. UnihashesV3,
  223. and_(
  224. UnihashesV3.method == OuthashesV2.method,
  225. UnihashesV3.taskhash == OuthashesV2.taskhash,
  226. ),
  227. )
  228. .where(
  229. OuthashesV2.method == method,
  230. OuthashesV2.outhash == outhash,
  231. )
  232. .order_by(
  233. OuthashesV2.created.asc(),
  234. )
  235. .limit(1)
  236. )
  237. return map_row(result.first())
  238. async def unihash_exists(self, unihash):
  239. async with self.db.begin():
  240. result = await self._execute(
  241. select(UnihashesV3).where(UnihashesV3.unihash == unihash).limit(1)
  242. )
  243. return result.first() is not None
  244. async def get_outhash(self, method, outhash):
  245. async with self.db.begin():
  246. result = await self._execute(
  247. select(OuthashesV2)
  248. .where(
  249. OuthashesV2.method == method,
  250. OuthashesV2.outhash == outhash,
  251. )
  252. .order_by(
  253. OuthashesV2.created.asc(),
  254. )
  255. .limit(1)
  256. )
  257. return map_row(result.first())
  258. async def get_equivalent_for_outhash(self, method, outhash, taskhash):
  259. async with self.db.begin():
  260. result = await self._execute(
  261. select(
  262. OuthashesV2.taskhash.label("taskhash"),
  263. UnihashesV3.unihash.label("unihash"),
  264. )
  265. .join(
  266. UnihashesV3,
  267. and_(
  268. UnihashesV3.method == OuthashesV2.method,
  269. UnihashesV3.taskhash == OuthashesV2.taskhash,
  270. ),
  271. )
  272. .where(
  273. OuthashesV2.method == method,
  274. OuthashesV2.outhash == outhash,
  275. OuthashesV2.taskhash != taskhash,
  276. )
  277. .order_by(
  278. OuthashesV2.created.asc(),
  279. )
  280. .limit(1)
  281. )
  282. return map_row(result.first())
  283. async def get_equivalent(self, method, taskhash):
  284. async with self.db.begin():
  285. result = await self._execute(
  286. select(
  287. UnihashesV3.unihash,
  288. UnihashesV3.method,
  289. UnihashesV3.taskhash,
  290. ).where(
  291. UnihashesV3.method == method,
  292. UnihashesV3.taskhash == taskhash,
  293. )
  294. )
  295. return map_row(result.first())
  296. async def remove(self, condition):
  297. async def do_remove(table):
  298. where = _make_condition_statement(table, condition)
  299. if where:
  300. async with self.db.begin():
  301. result = await self._execute(delete(table).where(*where))
  302. return result.rowcount
  303. return 0
  304. count = 0
  305. count += await do_remove(UnihashesV3)
  306. count += await do_remove(OuthashesV2)
  307. return count
  308. async def get_current_gc_mark(self):
  309. async with self.db.begin():
  310. return await self._get_config("gc-mark")
  311. async def gc_status(self):
  312. async with self.db.begin():
  313. gc_mark_subquery = self._get_config_subquery("gc-mark", "")
  314. result = await self._execute(
  315. select(func.count())
  316. .select_from(UnihashesV3)
  317. .where(UnihashesV3.gc_mark == gc_mark_subquery)
  318. )
  319. keep_rows = result.scalar()
  320. result = await self._execute(
  321. select(func.count())
  322. .select_from(UnihashesV3)
  323. .where(UnihashesV3.gc_mark != gc_mark_subquery)
  324. )
  325. remove_rows = result.scalar()
  326. return (keep_rows, remove_rows, await self._get_config("gc-mark"))
  327. async def gc_mark(self, mark, condition):
  328. async with self.db.begin():
  329. await self._set_config("gc-mark", mark)
  330. where = _make_condition_statement(UnihashesV3, condition)
  331. if not where:
  332. return 0
  333. result = await self._execute(
  334. update(UnihashesV3)
  335. .values(gc_mark=self._get_config_subquery("gc-mark", ""))
  336. .where(*where)
  337. )
  338. return result.rowcount
  339. async def gc_sweep(self):
  340. async with self.db.begin():
  341. result = await self._execute(
  342. delete(UnihashesV3).where(
  343. # A sneaky conditional that provides some errant use
  344. # protection: If the config mark is NULL, this will not
  345. # match any rows because No default is specified in the
  346. # select statement
  347. UnihashesV3.gc_mark
  348. != self._get_config_subquery("gc-mark")
  349. )
  350. )
  351. await self._set_config("gc-mark", None)
  352. return result.rowcount
  353. async def clean_unused(self, oldest):
  354. async with self.db.begin():
  355. result = await self._execute(
  356. delete(OuthashesV2).where(
  357. OuthashesV2.created < oldest,
  358. ~(
  359. select(UnihashesV3.id)
  360. .where(
  361. UnihashesV3.method == OuthashesV2.method,
  362. UnihashesV3.taskhash == OuthashesV2.taskhash,
  363. )
  364. .limit(1)
  365. .exists()
  366. ),
  367. )
  368. )
  369. return result.rowcount
  370. async def insert_unihash(self, method, taskhash, unihash):
  371. # Postgres specific ignore on insert duplicate
  372. if self.engine.name == "postgresql":
  373. statement = (
  374. postgres_insert(UnihashesV3)
  375. .values(
  376. method=method,
  377. taskhash=taskhash,
  378. unihash=unihash,
  379. gc_mark=self._get_config_subquery("gc-mark", ""),
  380. )
  381. .on_conflict_do_nothing(index_elements=("method", "taskhash"))
  382. )
  383. else:
  384. statement = insert(UnihashesV3).values(
  385. method=method,
  386. taskhash=taskhash,
  387. unihash=unihash,
  388. gc_mark=self._get_config_subquery("gc-mark", ""),
  389. )
  390. try:
  391. async with self.db.begin():
  392. result = await self._execute(statement)
  393. return result.rowcount != 0
  394. except IntegrityError:
  395. self.logger.debug(
  396. "%s, %s, %s already in unihash database", method, taskhash, unihash
  397. )
  398. return False
  399. async def insert_outhash(self, data):
  400. outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)
  401. data = {k: v for k, v in data.items() if k in outhash_columns}
  402. if "created" in data and not isinstance(data["created"], datetime):
  403. data["created"] = datetime.fromisoformat(data["created"])
  404. # Postgres specific ignore on insert duplicate
  405. if self.engine.name == "postgresql":
  406. statement = (
  407. postgres_insert(OuthashesV2)
  408. .values(**data)
  409. .on_conflict_do_nothing(
  410. index_elements=("method", "taskhash", "outhash")
  411. )
  412. )
  413. else:
  414. statement = insert(OuthashesV2).values(**data)
  415. try:
  416. async with self.db.begin():
  417. result = await self._execute(statement)
  418. return result.rowcount != 0
  419. except IntegrityError:
  420. self.logger.debug(
  421. "%s, %s already in outhash database", data["method"], data["outhash"]
  422. )
  423. return False
  424. async def _get_user(self, username):
  425. async with self.db.begin():
  426. result = await self._execute(
  427. select(
  428. Users.username,
  429. Users.permissions,
  430. Users.token,
  431. ).where(
  432. Users.username == username,
  433. )
  434. )
  435. return result.first()
  436. async def lookup_user_token(self, username):
  437. row = await self._get_user(username)
  438. if not row:
  439. return None, None
  440. return map_user(row), row.token
  441. async def lookup_user(self, username):
  442. return map_user(await self._get_user(username))
  443. async def set_user_token(self, username, token):
  444. async with self.db.begin():
  445. result = await self._execute(
  446. update(Users)
  447. .where(
  448. Users.username == username,
  449. )
  450. .values(
  451. token=token,
  452. )
  453. )
  454. return result.rowcount != 0
  455. async def set_user_perms(self, username, permissions):
  456. async with self.db.begin():
  457. result = await self._execute(
  458. update(Users)
  459. .where(Users.username == username)
  460. .values(permissions=" ".join(permissions))
  461. )
  462. return result.rowcount != 0
  463. async def get_all_users(self):
  464. async with self.db.begin():
  465. result = await self._execute(
  466. select(
  467. Users.username,
  468. Users.permissions,
  469. )
  470. )
  471. return [map_user(row) for row in result]
  472. async def new_user(self, username, permissions, token):
  473. try:
  474. async with self.db.begin():
  475. await self._execute(
  476. insert(Users).values(
  477. username=username,
  478. permissions=" ".join(permissions),
  479. token=token,
  480. )
  481. )
  482. return True
  483. except IntegrityError as e:
  484. self.logger.debug("Cannot create new user %s: %s", username, e)
  485. return False
  486. async def delete_user(self, username):
  487. async with self.db.begin():
  488. result = await self._execute(
  489. delete(Users).where(Users.username == username)
  490. )
  491. return result.rowcount != 0
  492. async def get_usage(self):
  493. usage = {}
  494. async with self.db.begin() as session:
  495. for name, table in Base.metadata.tables.items():
  496. result = await self._execute(
  497. statement=select(func.count()).select_from(table)
  498. )
  499. usage[name] = {
  500. "rows": result.scalar(),
  501. }
  502. return usage
  503. async def get_query_columns(self):
  504. columns = set()
  505. for table in (UnihashesV3, OuthashesV2):
  506. for c in table.__table__.columns:
  507. if not isinstance(c.type, Text):
  508. continue
  509. columns.add(c.key)
  510. return list(columns)