123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598 |
- #! /usr/bin/env python3
- #
- # Copyright (C) 2023 Garmin Ltd.
- #
- # SPDX-License-Identifier: GPL-2.0-only
- #
- import logging
- from datetime import datetime
- from . import User
- from sqlalchemy.ext.asyncio import create_async_engine
- from sqlalchemy.pool import NullPool
- from sqlalchemy import (
- MetaData,
- Column,
- Table,
- Text,
- Integer,
- UniqueConstraint,
- DateTime,
- Index,
- select,
- insert,
- exists,
- literal,
- and_,
- delete,
- update,
- func,
- inspect,
- )
- import sqlalchemy.engine
- from sqlalchemy.orm import declarative_base
- from sqlalchemy.exc import IntegrityError
- from sqlalchemy.dialects.postgresql import insert as postgres_insert
# Declarative base for the current schema. DatabaseEngine.create() runs
# Base.metadata.create_all, so every model deriving from Base is created on
# startup, and Database.get_usage() reports on every table registered here.
Base = declarative_base()
class UnihashesV3(Base):
    """Current unihash table: one unihash per (method, taskhash) pair."""

    __tablename__ = "unihashes_v3"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)
    # Compared against the "gc-mark" config value by the garbage collection
    # queries; rows migrated from unihashes_v2 are given an empty string
    gc_mark = Column(Text, nullable=False)

    __table_args__ = (
        # A (method, taskhash) pair may only be stored once
        UniqueConstraint("method", "taskhash"),
        Index("taskhash_lookup_v4", "method", "taskhash"),
        Index("unihash_lookup_v1", "unihash"),
    )
class OuthashesV2(Base):
    """Output hash table: one row per (method, taskhash, outhash) triple."""

    __tablename__ = "outhashes_v2"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    outhash = Column(Text, nullable=False)
    # Lookups order by this column ascending (oldest row wins) and
    # Database.clean_unused() uses it as the age cutoff
    created = Column(DateTime)
    # The remaining columns are optional (nullable) metadata stored with the
    # row; NOTE(review): presumably recipe/task information supplied by the
    # reporting client - confirm against the caller
    owner = Column(Text)
    PN = Column(Text)
    PV = Column(Text)
    PR = Column(Text)
    task = Column(Text)
    outhash_siginfo = Column(Text)

    __table_args__ = (
        # The same triple may only be stored once
        UniqueConstraint("method", "taskhash", "outhash"),
        Index("outhash_lookup_v3", "method", "outhash"),
    )
class Users(Base):
    """User accounts, keyed by a unique username."""

    __tablename__ = "users"
    id = Column(Integer, primary_key=True, autoincrement=True)
    username = Column(Text, nullable=False)
    token = Column(Text, nullable=False)
    # Space-separated permission names: written with " ".join() and read
    # back with str.split(). The column is nullable.
    permissions = Column(Text)

    __table_args__ = (UniqueConstraint("username"),)
class Config(Base):
    """Key/value settings table; the code in this file stores "gc-mark" here."""

    __tablename__ = "config"
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(Text, nullable=False)
    # Nullable: Database.gc_sweep() resets "gc-mark" by writing NULL
    value = Column(Text)

    __table_args__ = (
        # Each setting name may only be stored once
        UniqueConstraint("name"),
        Index("config_lookup", "name"),
    )
#
# Old table versions
#
# These models use a separate declarative base so that
# Base.metadata.create_all does not create them; they only describe tables
# that may still exist in an older database and need migrating.
DeprecatedBase = declarative_base()
class UnihashesV2(DeprecatedBase):
    """Previous unihash table layout (no gc_mark column).

    DatabaseEngine.create() copies its rows into unihashes_v3 and then drops
    this table when it is found in the database.
    """

    __tablename__ = "unihashes_v2"
    id = Column(Integer, primary_key=True, autoincrement=True)
    method = Column(Text, nullable=False)
    taskhash = Column(Text, nullable=False)
    unihash = Column(Text, nullable=False)

    __table_args__ = (
        UniqueConstraint("method", "taskhash"),
        Index("taskhash_lookup_v3", "method", "taskhash"),
    )
class DatabaseEngine(object):
    """Owns the async SQLAlchemy engine and prepares the database schema.

    Call ``create()`` once before ``connect()``: it builds the engine,
    creates the current tables and migrates data from older table versions.
    """

    def __init__(self, url, username=None, password=None):
        """Parse *url* and optionally override its credentials.

        :param url: database URL accepted by ``sqlalchemy.engine.make_url``
        :param username: if not None, replaces the username in *url*
        :param password: if not None, replaces the password in *url*
        """
        self.logger = logging.getLogger("hashserv.sqlalchemy")
        self.url = sqlalchemy.engine.make_url(url)

        if username is not None:
            self.url = self.url.set(username=username)

        if password is not None:
            self.url = self.url.set(password=password)

    async def create(self):
        """Create the engine, create missing tables and upgrade old ones."""

        def check_table_exists(conn, name):
            # Runs synchronously through run_sync(); inspect() needs a sync
            # connection
            return inspect(conn).has_table(name)

        # Render the URL explicitly with the password masked so credentials
        # never reach the log, independent of how str(URL) behaves in the
        # installed SQLAlchemy version
        self.logger.info(
            "Using database %s", self.url.render_as_string(hide_password=True)
        )
        if self.url.drivername == "postgresql+psycopg":
            # Psycopg 3 (psycopg) driver can handle async connection pooling
            self.engine = create_async_engine(self.url, max_overflow=-1)
        else:
            # Other drivers get no pooling at all
            self.engine = create_async_engine(self.url, poolclass=NullPool)

        async with self.engine.begin() as conn:
            # Create tables
            self.logger.info("Creating tables...")
            await conn.run_sync(Base.metadata.create_all)

            if await conn.run_sync(check_table_exists, UnihashesV2.__tablename__):
                self.logger.info("Upgrading Unihashes V2 -> V3...")
                # Copy every V2 row into V3; V2 has no gc_mark column, so the
                # migrated rows get an empty mark
                statement = insert(UnihashesV3).from_select(
                    ["id", "method", "unihash", "taskhash", "gc_mark"],
                    select(
                        UnihashesV2.id,
                        UnihashesV2.method,
                        UnihashesV2.unihash,
                        UnihashesV2.taskhash,
                        literal("").label("gc_mark"),
                    ),
                )
                self.logger.debug("%s", statement)
                await conn.execute(statement)

                # The old table is no longer needed once its rows are copied
                await conn.run_sync(Base.metadata.drop_all, [UnihashesV2.__table__])
                self.logger.info("Upgrade complete")

    def connect(self, logger):
        """Return a ``Database`` bound to this engine that logs to *logger*.

        The returned object does not open a connection until it is entered
        as an async context manager.
        """
        return Database(self.engine, logger)
def map_row(row):
    """Convert a result row into a plain dict, passing ``None`` through."""
    return None if row is None else dict(**row._mapping)
def map_user(row):
    """Build a ``User`` from a users-table row, passing ``None`` through.

    The ``permissions`` column holds a whitespace-separated list of names and
    is nullable; a NULL value is treated as "no permissions" instead of
    raising ``AttributeError`` on ``None.split()``.
    """
    if row is None:
        return None

    return User(
        username=row.username,
        permissions=set((row.permissions or "").split()),
    )
def _make_condition_statement(table, condition):
    """Translate *condition* into a list of ``column == value`` clauses.

    Only keys of *condition* that name a column of *table* and carry a
    non-None value contribute a clause; everything else is ignored. The
    result is empty when nothing matches.
    """
    return [
        column == condition[column.key]
        for column in table.__table__.columns
        if column.key in condition and condition[column.key] is not None
    ]
class Database(object):
    """One database connection exposing the hash server queries.

    Instances come from ``DatabaseEngine.connect()`` and are used as an async
    context manager: the connection is opened on entry and closed on exit.
    Public query methods run inside their own ``self.db.begin()``
    transaction.
    """

    def __init__(self, engine, logger):
        self.engine = engine
        # Set by __aenter__(), reset to None by close()
        self.db = None
        self.logger = logger

    async def __aenter__(self):
        self.db = await self.engine.connect()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.close()

    async def close(self):
        """Close the connection; does nothing when no connection is open."""
        if self.db is None:
            return

        await self.db.close()
        self.db = None

    async def _execute(self, statement):
        """Log *statement* at debug level and execute it on the connection."""
        self.logger.debug("%s", statement)
        return await self.db.execute(statement)

    async def _set_config(self, name, value):
        """Set config entry *name* to *value*, inserting the row if needed."""
        while True:
            result = await self._execute(
                update(Config).where(Config.name == name).values(value=value)
            )

            if result.rowcount == 0:
                self.logger.debug("Config '%s' not found. Adding it", name)
                try:
                    await self._execute(insert(Config).values(name=name, value=value))
                except IntegrityError:
                    # Race: the row appeared between the update and the
                    # insert. Try the update again
                    continue

            break

    def _get_config_subquery(self, name, default=None):
        """Return a scalar subquery selecting the value of config *name*.

        When *default* is not None the subquery is wrapped in COALESCE so a
        missing or NULL entry evaluates to *default*.
        """
        subquery = select(Config.value).where(Config.name == name).scalar_subquery()
        if default is not None:
            return func.coalesce(subquery, default)
        return subquery

    async def _get_config(self, name):
        """Return the value of config *name*, or None when it has no row."""
        result = await self._execute(select(Config.value).where(Config.name == name))
        row = result.first()
        if row is None:
            return None
        return row.value

    async def get_unihash_by_taskhash_full(self, method, taskhash):
        """Return the oldest outhash row for (method, taskhash) plus its unihash.

        The result is a dict, or None when no joined row exists.
        """
        async with self.db.begin():
            result = await self._execute(
                select(
                    OuthashesV2,
                    UnihashesV3.unihash.label("unihash"),
                )
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.taskhash == taskhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_unihash_by_outhash(self, method, outhash):
        """Return the oldest outhash row for (method, outhash) plus its unihash.

        The result is a dict, or None when no joined row exists.
        """
        async with self.db.begin():
            result = await self._execute(
                select(OuthashesV2, UnihashesV3.unihash.label("unihash"))
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def unihash_exists(self, unihash):
        """Return True when at least one row stores *unihash*."""
        async with self.db.begin():
            result = await self._execute(
                select(UnihashesV3).where(UnihashesV3.unihash == unihash).limit(1)
            )
            return result.first() is not None

    async def get_outhash(self, method, outhash):
        """Return the oldest outhash row for (method, outhash) as a dict or None."""
        async with self.db.begin():
            result = await self._execute(
                select(OuthashesV2)
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_equivalent_for_outhash(self, method, outhash, taskhash):
        """Find the oldest other taskhash that produced the same outhash.

        Returns a dict with ``taskhash`` and ``unihash`` keys for the oldest
        row matching (method, outhash) whose taskhash differs from
        *taskhash*, or None when there is none.
        """
        async with self.db.begin():
            result = await self._execute(
                select(
                    OuthashesV2.taskhash.label("taskhash"),
                    UnihashesV3.unihash.label("unihash"),
                )
                .join(
                    UnihashesV3,
                    and_(
                        UnihashesV3.method == OuthashesV2.method,
                        UnihashesV3.taskhash == OuthashesV2.taskhash,
                    ),
                )
                .where(
                    OuthashesV2.method == method,
                    OuthashesV2.outhash == outhash,
                    OuthashesV2.taskhash != taskhash,
                )
                .order_by(
                    OuthashesV2.created.asc(),
                )
                .limit(1)
            )
            return map_row(result.first())

    async def get_equivalent(self, method, taskhash):
        """Return the unihash row for (method, taskhash) as a dict or None."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    UnihashesV3.unihash,
                    UnihashesV3.method,
                    UnihashesV3.taskhash,
                ).where(
                    UnihashesV3.method == method,
                    UnihashesV3.taskhash == taskhash,
                )
            )
            return map_row(result.first())

    async def remove(self, condition):
        """Delete unihash and outhash rows matching *condition*.

        *condition* maps column names to required values. A table is left
        untouched when none of its columns are named with a non-None value,
        so an empty condition never deletes everything. Returns the total
        number of deleted rows.
        """

        async def do_remove(table):
            where = _make_condition_statement(table, condition)
            if where:
                async with self.db.begin():
                    result = await self._execute(delete(table).where(*where))
                return result.rowcount

            return 0

        count = 0
        count += await do_remove(UnihashesV3)
        count += await do_remove(OuthashesV2)

        return count

    async def get_current_gc_mark(self):
        """Return the stored "gc-mark" config value, or None when unset."""
        async with self.db.begin():
            return await self._get_config("gc-mark")

    async def gc_status(self):
        """Return ``(keep_rows, remove_rows, current_mark)``.

        ``keep_rows`` counts unihash rows whose gc_mark equals the current
        mark (an unset mark is treated as the empty string) and
        ``remove_rows`` counts the rows whose gc_mark differs from it.
        """
        async with self.db.begin():
            gc_mark_subquery = self._get_config_subquery("gc-mark", "")

            result = await self._execute(
                select(func.count())
                .select_from(UnihashesV3)
                .where(UnihashesV3.gc_mark == gc_mark_subquery)
            )
            keep_rows = result.scalar()

            result = await self._execute(
                select(func.count())
                .select_from(UnihashesV3)
                .where(UnihashesV3.gc_mark != gc_mark_subquery)
            )
            remove_rows = result.scalar()

            return (keep_rows, remove_rows, await self._get_config("gc-mark"))

    async def gc_mark(self, mark, condition):
        """Store *mark* as the current mark and stamp matching unihash rows.

        Returns the number of rows updated; 0 when *condition* names no
        usable column (the mark itself is still stored in that case).
        """
        async with self.db.begin():
            await self._set_config("gc-mark", mark)

            where = _make_condition_statement(UnihashesV3, condition)
            if not where:
                return 0

            result = await self._execute(
                update(UnihashesV3)
                .values(gc_mark=self._get_config_subquery("gc-mark", ""))
                .where(*where)
            )
            return result.rowcount

    async def gc_sweep(self):
        """Delete unihash rows not carrying the current mark, then clear it.

        Returns the number of deleted rows.
        """
        async with self.db.begin():
            result = await self._execute(
                delete(UnihashesV3).where(
                    # A sneaky conditional that provides some errant use
                    # protection: If the config mark is NULL, this will not
                    # match any rows because no default is specified in the
                    # select statement
                    UnihashesV3.gc_mark
                    != self._get_config_subquery("gc-mark")
                )
            )
            await self._set_config("gc-mark", None)

            return result.rowcount

    async def clean_unused(self, oldest):
        """Delete outhash rows created before *oldest* that have no unihash.

        Returns the number of deleted rows.
        """
        async with self.db.begin():
            result = await self._execute(
                delete(OuthashesV2).where(
                    OuthashesV2.created < oldest,
                    ~(
                        select(UnihashesV3.id)
                        .where(
                            UnihashesV3.method == OuthashesV2.method,
                            UnihashesV3.taskhash == OuthashesV2.taskhash,
                        )
                        .limit(1)
                        .exists()
                    ),
                )
            )
            return result.rowcount

    async def insert_unihash(self, method, taskhash, unihash):
        """Insert a unihash row stamped with the current gc mark.

        Returns True when a row was inserted and False when a row for
        (method, taskhash) already exists.
        """
        values = dict(
            method=method,
            taskhash=taskhash,
            unihash=unihash,
            gc_mark=self._get_config_subquery("gc-mark", ""),
        )

        # Postgres specific ignore on insert duplicate
        if self.engine.name == "postgresql":
            statement = (
                postgres_insert(UnihashesV3)
                .values(**values)
                .on_conflict_do_nothing(index_elements=("method", "taskhash"))
            )
        else:
            statement = insert(UnihashesV3).values(**values)

        try:
            async with self.db.begin():
                result = await self._execute(statement)
                return result.rowcount != 0
        except IntegrityError:
            self.logger.debug(
                "%s, %s, %s already in unihash database", method, taskhash, unihash
            )
            return False

    async def insert_outhash(self, data):
        """Insert an outhash row built from *data*.

        Keys of *data* that are not outhash columns are dropped, and a
        non-datetime ``created`` value is parsed with
        ``datetime.fromisoformat``. Returns True when a row was inserted and
        False when the insert was rejected as a duplicate.
        """
        outhash_columns = set(c.key for c in OuthashesV2.__table__.columns)

        data = {k: v for k, v in data.items() if k in outhash_columns}

        if "created" in data and not isinstance(data["created"], datetime):
            data["created"] = datetime.fromisoformat(data["created"])

        # Postgres specific ignore on insert duplicate
        if self.engine.name == "postgresql":
            statement = (
                postgres_insert(OuthashesV2)
                .values(**data)
                .on_conflict_do_nothing(
                    index_elements=("method", "taskhash", "outhash")
                )
            )
        else:
            statement = insert(OuthashesV2).values(**data)

        try:
            async with self.db.begin():
                result = await self._execute(statement)
                return result.rowcount != 0
        except IntegrityError:
            # .get(): the handler must not raise KeyError itself when the
            # caller omitted one of these keys
            self.logger.debug(
                "%s, %s already in outhash database",
                data.get("method"),
                data.get("outhash"),
            )
            return False

    async def _get_user(self, username):
        """Return the raw users row for *username*, or None when absent."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    Users.username,
                    Users.permissions,
                    Users.token,
                ).where(
                    Users.username == username,
                )
            )
            return result.first()

    async def lookup_user_token(self, username):
        """Return ``(user, token)`` for *username*, or ``(None, None)``."""
        row = await self._get_user(username)
        if row is None:
            return None, None
        return map_user(row), row.token

    async def lookup_user(self, username):
        """Return the mapped user for *username*, or None when absent."""
        return map_user(await self._get_user(username))

    async def set_user_token(self, username, token):
        """Replace the token of *username*; True when a row was updated."""
        async with self.db.begin():
            result = await self._execute(
                update(Users)
                .where(
                    Users.username == username,
                )
                .values(
                    token=token,
                )
            )
            return result.rowcount != 0

    async def set_user_perms(self, username, permissions):
        """Replace the permissions of *username*; True when a row was updated.

        *permissions* is an iterable of names stored space-separated.
        """
        async with self.db.begin():
            result = await self._execute(
                update(Users)
                .where(Users.username == username)
                .values(permissions=" ".join(permissions))
            )
            return result.rowcount != 0

    async def get_all_users(self):
        """Return every user as a list of mapped users (tokens not loaded)."""
        async with self.db.begin():
            result = await self._execute(
                select(
                    Users.username,
                    Users.permissions,
                )
            )
            return [map_user(row) for row in result]

    async def new_user(self, username, permissions, token):
        """Create a user; False when the insert violates a constraint."""
        try:
            async with self.db.begin():
                await self._execute(
                    insert(Users).values(
                        username=username,
                        permissions=" ".join(permissions),
                        token=token,
                    )
                )
            return True
        except IntegrityError as e:
            self.logger.debug("Cannot create new user %s: %s", username, e)
            return False

    async def delete_user(self, username):
        """Delete *username*; True when a row was removed."""
        async with self.db.begin():
            result = await self._execute(
                delete(Users).where(Users.username == username)
            )
            return result.rowcount != 0

    async def get_usage(self):
        """Return ``{table_name: {"rows": count}}`` for every current table."""
        usage = {}
        async with self.db.begin():
            for name, table in Base.metadata.tables.items():
                result = await self._execute(select(func.count()).select_from(table))
                usage[name] = {
                    "rows": result.scalar(),
                }

        return usage

    async def get_query_columns(self):
        """Return the names of all Text columns of the hash tables.

        Names shared by both tables are reported once; the order of the
        returned list is unspecified.
        """
        columns = {
            c.key
            for table in (UnihashesV3, OuthashesV2)
            for c in table.__table__.columns
            if isinstance(c.type, Text)
        }
        return list(columns)
|