server.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910
  1. # Copyright (C) 2019 Garmin Ltd.
  2. #
  3. # SPDX-License-Identifier: GPL-2.0-only
  4. #
  5. from datetime import datetime, timedelta
  6. import asyncio
  7. import logging
  8. import math
  9. import time
  10. import os
  11. import base64
  12. import json
  13. import hashlib
  14. from . import create_async_client
  15. import bb.asyncrpc
  16. logger = logging.getLogger("hashserv.server")
  17. # This permission only exists to match nothing
  18. NONE_PERM = "@none"
  19. READ_PERM = "@read"
  20. REPORT_PERM = "@report"
  21. DB_ADMIN_PERM = "@db-admin"
  22. USER_ADMIN_PERM = "@user-admin"
  23. ALL_PERM = "@all"
  24. ALL_PERMISSIONS = {
  25. READ_PERM,
  26. REPORT_PERM,
  27. DB_ADMIN_PERM,
  28. USER_ADMIN_PERM,
  29. ALL_PERM,
  30. }
  31. DEFAULT_ANON_PERMS = (
  32. READ_PERM,
  33. REPORT_PERM,
  34. DB_ADMIN_PERM,
  35. )
  36. TOKEN_ALGORITHM = "sha256"
  37. # 48 bytes of random data will result in 64 characters when base64
  38. # encoded. This number also ensures that the base64 encoding won't have any
  39. # trailing '=' characters.
  40. TOKEN_SIZE = 48
  41. SALT_SIZE = 8
  42. class Measurement(object):
  43. def __init__(self, sample):
  44. self.sample = sample
  45. def start(self):
  46. self.start_time = time.perf_counter()
  47. def end(self):
  48. self.sample.add(time.perf_counter() - self.start_time)
  49. def __enter__(self):
  50. self.start()
  51. return self
  52. def __exit__(self, *args, **kwargs):
  53. self.end()
  54. class Sample(object):
  55. def __init__(self, stats):
  56. self.stats = stats
  57. self.num_samples = 0
  58. self.elapsed = 0
  59. def measure(self):
  60. return Measurement(self)
  61. def __enter__(self):
  62. return self
  63. def __exit__(self, *args, **kwargs):
  64. self.end()
  65. def add(self, elapsed):
  66. self.num_samples += 1
  67. self.elapsed += elapsed
  68. def end(self):
  69. if self.num_samples:
  70. self.stats.add(self.elapsed)
  71. self.num_samples = 0
  72. self.elapsed = 0
  73. class Stats(object):
  74. def __init__(self):
  75. self.reset()
  76. def reset(self):
  77. self.num = 0
  78. self.total_time = 0
  79. self.max_time = 0
  80. self.m = 0
  81. self.s = 0
  82. self.current_elapsed = None
  83. def add(self, elapsed):
  84. self.num += 1
  85. if self.num == 1:
  86. self.m = elapsed
  87. self.s = 0
  88. else:
  89. last_m = self.m
  90. self.m = last_m + (elapsed - last_m) / self.num
  91. self.s = self.s + (elapsed - last_m) * (elapsed - self.m)
  92. self.total_time += elapsed
  93. if self.max_time < elapsed:
  94. self.max_time = elapsed
  95. def start_sample(self):
  96. return Sample(self)
  97. @property
  98. def average(self):
  99. if self.num == 0:
  100. return 0
  101. return self.total_time / self.num
  102. @property
  103. def stdev(self):
  104. if self.num <= 1:
  105. return 0
  106. return math.sqrt(self.s / (self.num - 1))
  107. def todict(self):
  108. return {
  109. k: getattr(self, k)
  110. for k in ("num", "total_time", "max_time", "average", "stdev")
  111. }
  112. token_refresh_semaphore = asyncio.Lock()
  113. async def new_token():
  114. # Prevent malicious users from using this API to deduce the entropy
  115. # pool on the server and thus be able to guess a token. *All* token
  116. # refresh requests lock the same global semaphore and then sleep for a
  117. # short time. The effectively rate limits the total number of requests
  118. # than can be made across all clients to 10/second, which should be enough
  119. # since you have to be an authenticated users to make the request in the
  120. # first place
  121. async with token_refresh_semaphore:
  122. await asyncio.sleep(0.1)
  123. raw = os.getrandom(TOKEN_SIZE, os.GRND_NONBLOCK)
  124. return base64.b64encode(raw, b"._").decode("utf-8")
  125. def new_salt():
  126. return os.getrandom(SALT_SIZE, os.GRND_NONBLOCK).hex()
  127. def hash_token(algo, salt, token):
  128. h = hashlib.new(algo)
  129. h.update(salt.encode("utf-8"))
  130. h.update(token.encode("utf-8"))
  131. return ":".join([algo, salt, h.hexdigest()])
  132. def permissions(*permissions, allow_anon=True, allow_self_service=False):
  133. """
  134. Function decorator that can be used to decorate an RPC function call and
  135. check that the current users permissions match the require permissions.
  136. If allow_anon is True, the user will also be allowed to make the RPC call
  137. if the anonymous user permissions match the permissions.
  138. If allow_self_service is True, and the "username" property in the request
  139. is the currently logged in user, or not specified, the user will also be
  140. allowed to make the request. This allows users to access normal privileged
  141. API, as long as they are only modifying their own user properties (e.g.
  142. users can be allowed to reset their own token without @user-admin
  143. permissions, but not the token for any other user.
  144. """
  145. def wrapper(func):
  146. async def wrap(self, request):
  147. if allow_self_service and self.user is not None:
  148. username = request.get("username", self.user.username)
  149. if username == self.user.username:
  150. request["username"] = self.user.username
  151. return await func(self, request)
  152. if not self.user_has_permissions(*permissions, allow_anon=allow_anon):
  153. if not self.user:
  154. username = "Anonymous user"
  155. user_perms = self.server.anon_perms
  156. else:
  157. username = self.user.username
  158. user_perms = self.user.permissions
  159. self.logger.info(
  160. "User %s with permissions %r denied from calling %s. Missing permissions(s) %r",
  161. username,
  162. ", ".join(user_perms),
  163. func.__name__,
  164. ", ".join(permissions),
  165. )
  166. raise bb.asyncrpc.InvokeError(
  167. f"{username} is not allowed to access permissions(s) {', '.join(permissions)}"
  168. )
  169. return await func(self, request)
  170. return wrap
  171. return wrapper
  172. class ServerClient(bb.asyncrpc.AsyncServerConnection):
  173. def __init__(self, socket, server):
  174. super().__init__(socket, "OEHASHEQUIV", server.logger)
  175. self.server = server
  176. self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
  177. self.user = None
  178. self.handlers.update(
  179. {
  180. "get": self.handle_get,
  181. "get-outhash": self.handle_get_outhash,
  182. "get-stream": self.handle_get_stream,
  183. "exists-stream": self.handle_exists_stream,
  184. "get-stats": self.handle_get_stats,
  185. "get-db-usage": self.handle_get_db_usage,
  186. "get-db-query-columns": self.handle_get_db_query_columns,
  187. # Not always read-only, but internally checks if the server is
  188. # read-only
  189. "report": self.handle_report,
  190. "auth": self.handle_auth,
  191. "get-user": self.handle_get_user,
  192. "get-all-users": self.handle_get_all_users,
  193. "become-user": self.handle_become_user,
  194. }
  195. )
  196. if not self.server.read_only:
  197. self.handlers.update(
  198. {
  199. "report-equiv": self.handle_equivreport,
  200. "reset-stats": self.handle_reset_stats,
  201. "backfill-wait": self.handle_backfill_wait,
  202. "remove": self.handle_remove,
  203. "gc-mark": self.handle_gc_mark,
  204. "gc-mark-stream": self.handle_gc_mark_stream,
  205. "gc-sweep": self.handle_gc_sweep,
  206. "gc-status": self.handle_gc_status,
  207. "clean-unused": self.handle_clean_unused,
  208. "refresh-token": self.handle_refresh_token,
  209. "set-user-perms": self.handle_set_perms,
  210. "new-user": self.handle_new_user,
  211. "delete-user": self.handle_delete_user,
  212. }
  213. )
  214. def raise_no_user_error(self, username):
  215. raise bb.asyncrpc.InvokeError(f"No user named '{username}' exists")
  216. def user_has_permissions(self, *permissions, allow_anon=True):
  217. permissions = set(permissions)
  218. if allow_anon:
  219. if ALL_PERM in self.server.anon_perms:
  220. return True
  221. if not permissions - self.server.anon_perms:
  222. return True
  223. if self.user is None:
  224. return False
  225. if ALL_PERM in self.user.permissions:
  226. return True
  227. if not permissions - self.user.permissions:
  228. return True
  229. return False
  230. def validate_proto_version(self):
  231. return self.proto_version > (1, 0) and self.proto_version <= (1, 1)
  232. async def process_requests(self):
  233. async with self.server.db_engine.connect(self.logger) as db:
  234. self.db = db
  235. if self.server.upstream is not None:
  236. self.upstream_client = await create_async_client(self.server.upstream)
  237. else:
  238. self.upstream_client = None
  239. try:
  240. await super().process_requests()
  241. finally:
  242. if self.upstream_client is not None:
  243. await self.upstream_client.close()
  244. async def dispatch_message(self, msg):
  245. for k in self.handlers.keys():
  246. if k in msg:
  247. self.logger.debug("Handling %s" % k)
  248. if "stream" in k:
  249. return await self.handlers[k](msg[k])
  250. else:
  251. with self.server.request_stats.start_sample() as self.request_sample, self.request_sample.measure():
  252. return await self.handlers[k](msg[k])
  253. raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
  254. @permissions(READ_PERM)
  255. async def handle_get(self, request):
  256. method = request["method"]
  257. taskhash = request["taskhash"]
  258. fetch_all = request.get("all", False)
  259. return await self.get_unihash(method, taskhash, fetch_all)
  260. async def get_unihash(self, method, taskhash, fetch_all=False):
  261. d = None
  262. if fetch_all:
  263. row = await self.db.get_unihash_by_taskhash_full(method, taskhash)
  264. if row is not None:
  265. d = {k: row[k] for k in row.keys()}
  266. elif self.upstream_client is not None:
  267. d = await self.upstream_client.get_taskhash(method, taskhash, True)
  268. await self.update_unified(d)
  269. else:
  270. row = await self.db.get_equivalent(method, taskhash)
  271. if row is not None:
  272. d = {k: row[k] for k in row.keys()}
  273. elif self.upstream_client is not None:
  274. d = await self.upstream_client.get_taskhash(method, taskhash)
  275. await self.db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
  276. return d
  277. @permissions(READ_PERM)
  278. async def handle_get_outhash(self, request):
  279. method = request["method"]
  280. outhash = request["outhash"]
  281. taskhash = request["taskhash"]
  282. with_unihash = request.get("with_unihash", True)
  283. return await self.get_outhash(method, outhash, taskhash, with_unihash)
  284. async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
  285. d = None
  286. if with_unihash:
  287. row = await self.db.get_unihash_by_outhash(method, outhash)
  288. else:
  289. row = await self.db.get_outhash(method, outhash)
  290. if row is not None:
  291. d = {k: row[k] for k in row.keys()}
  292. elif self.upstream_client is not None:
  293. d = await self.upstream_client.get_outhash(method, outhash, taskhash)
  294. await self.update_unified(d)
  295. return d
  296. async def update_unified(self, data):
  297. if data is None:
  298. return
  299. await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
  300. await self.db.insert_outhash(data)
  301. async def _stream_handler(self, handler):
  302. await self.socket.send_message("ok")
  303. while True:
  304. upstream = None
  305. l = await self.socket.recv()
  306. if not l:
  307. break
  308. try:
  309. # This inner loop is very sensitive and must be as fast as
  310. # possible (which is why the request sample is handled manually
  311. # instead of using 'with', and also why logging statements are
  312. # commented out.
  313. self.request_sample = self.server.request_stats.start_sample()
  314. request_measure = self.request_sample.measure()
  315. request_measure.start()
  316. if l == "END":
  317. break
  318. msg = await handler(l)
  319. await self.socket.send(msg)
  320. finally:
  321. request_measure.end()
  322. self.request_sample.end()
  323. await self.socket.send("ok")
  324. return self.NO_RESPONSE
  325. @permissions(READ_PERM)
  326. async def handle_get_stream(self, request):
  327. async def handler(l):
  328. (method, taskhash) = l.split()
  329. # self.logger.debug('Looking up %s %s' % (method, taskhash))
  330. row = await self.db.get_equivalent(method, taskhash)
  331. if row is not None:
  332. # self.logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
  333. return row["unihash"]
  334. if self.upstream_client is not None:
  335. upstream = await self.upstream_client.get_unihash(method, taskhash)
  336. if upstream:
  337. await self.server.backfill_queue.put((method, taskhash))
  338. return upstream
  339. return ""
  340. return await self._stream_handler(handler)
  341. @permissions(READ_PERM)
  342. async def handle_exists_stream(self, request):
  343. async def handler(l):
  344. if await self.db.unihash_exists(l):
  345. return "true"
  346. if self.upstream_client is not None:
  347. if await self.upstream_client.unihash_exists(l):
  348. return "true"
  349. return "false"
  350. return await self._stream_handler(handler)
  351. async def report_readonly(self, data):
  352. method = data["method"]
  353. outhash = data["outhash"]
  354. taskhash = data["taskhash"]
  355. info = await self.get_outhash(method, outhash, taskhash)
  356. if info:
  357. unihash = info["unihash"]
  358. else:
  359. unihash = data["unihash"]
  360. return {
  361. "taskhash": taskhash,
  362. "method": method,
  363. "unihash": unihash,
  364. }
  365. # Since this can be called either read only or to report, the check to
  366. # report is made inside the function
  367. @permissions(READ_PERM)
  368. async def handle_report(self, data):
  369. if self.server.read_only or not self.user_has_permissions(REPORT_PERM):
  370. return await self.report_readonly(data)
  371. outhash_data = {
  372. "method": data["method"],
  373. "outhash": data["outhash"],
  374. "taskhash": data["taskhash"],
  375. "created": datetime.now(),
  376. }
  377. for k in ("owner", "PN", "PV", "PR", "task", "outhash_siginfo"):
  378. if k in data:
  379. outhash_data[k] = data[k]
  380. if self.user:
  381. outhash_data["owner"] = self.user.username
  382. # Insert the new entry, unless it already exists
  383. if await self.db.insert_outhash(outhash_data):
  384. # If this row is new, check if it is equivalent to another
  385. # output hash
  386. row = await self.db.get_equivalent_for_outhash(
  387. data["method"], data["outhash"], data["taskhash"]
  388. )
  389. if row is not None:
  390. # A matching output hash was found. Set our taskhash to the
  391. # same unihash since they are equivalent
  392. unihash = row["unihash"]
  393. else:
  394. # No matching output hash was found. This is probably the
  395. # first outhash to be added.
  396. unihash = data["unihash"]
  397. # Query upstream to see if it has a unihash we can use
  398. if self.upstream_client is not None:
  399. upstream_data = await self.upstream_client.get_outhash(
  400. data["method"], data["outhash"], data["taskhash"]
  401. )
  402. if upstream_data is not None:
  403. unihash = upstream_data["unihash"]
  404. await self.db.insert_unihash(data["method"], data["taskhash"], unihash)
  405. unihash_data = await self.get_unihash(data["method"], data["taskhash"])
  406. if unihash_data is not None:
  407. unihash = unihash_data["unihash"]
  408. else:
  409. unihash = data["unihash"]
  410. return {
  411. "taskhash": data["taskhash"],
  412. "method": data["method"],
  413. "unihash": unihash,
  414. }
  415. @permissions(READ_PERM, REPORT_PERM)
  416. async def handle_equivreport(self, data):
  417. await self.db.insert_unihash(data["method"], data["taskhash"], data["unihash"])
  418. # Fetch the unihash that will be reported for the taskhash. If the
  419. # unihash matches, it means this row was inserted (or the mapping
  420. # was already valid)
  421. row = await self.db.get_equivalent(data["method"], data["taskhash"])
  422. if row["unihash"] == data["unihash"]:
  423. self.logger.info(
  424. "Adding taskhash equivalence for %s with unihash %s",
  425. data["taskhash"],
  426. row["unihash"],
  427. )
  428. return {k: row[k] for k in ("taskhash", "method", "unihash")}
  429. @permissions(READ_PERM)
  430. async def handle_get_stats(self, request):
  431. return {
  432. "requests": self.server.request_stats.todict(),
  433. }
  434. @permissions(DB_ADMIN_PERM)
  435. async def handle_reset_stats(self, request):
  436. d = {
  437. "requests": self.server.request_stats.todict(),
  438. }
  439. self.server.request_stats.reset()
  440. return d
  441. @permissions(READ_PERM)
  442. async def handle_backfill_wait(self, request):
  443. d = {
  444. "tasks": self.server.backfill_queue.qsize(),
  445. }
  446. await self.server.backfill_queue.join()
  447. return d
  448. @permissions(DB_ADMIN_PERM)
  449. async def handle_remove(self, request):
  450. condition = request["where"]
  451. if not isinstance(condition, dict):
  452. raise TypeError("Bad condition type %s" % type(condition))
  453. return {"count": await self.db.remove(condition)}
  454. @permissions(DB_ADMIN_PERM)
  455. async def handle_gc_mark(self, request):
  456. condition = request["where"]
  457. mark = request["mark"]
  458. if not isinstance(condition, dict):
  459. raise TypeError("Bad condition type %s" % type(condition))
  460. if not isinstance(mark, str):
  461. raise TypeError("Bad mark type %s" % type(mark))
  462. return {"count": await self.db.gc_mark(mark, condition)}
  463. @permissions(DB_ADMIN_PERM)
  464. async def handle_gc_mark_stream(self, request):
  465. async def handler(line):
  466. try:
  467. decoded_line = json.loads(line)
  468. except json.JSONDecodeError as exc:
  469. raise bb.asyncrpc.InvokeError(
  470. "Could not decode JSONL input '%s'" % line
  471. ) from exc
  472. try:
  473. mark = decoded_line["mark"]
  474. condition = decoded_line["where"]
  475. if not isinstance(mark, str):
  476. raise TypeError("Bad mark type %s" % type(mark))
  477. if not isinstance(condition, dict):
  478. raise TypeError("Bad condition type %s" % type(condition))
  479. except KeyError as exc:
  480. raise bb.asyncrpc.InvokeError(
  481. "Input line is missing key '%s' " % exc
  482. ) from exc
  483. return json.dumps({"count": await self.db.gc_mark(mark, condition)})
  484. return await self._stream_handler(handler)
  485. @permissions(DB_ADMIN_PERM)
  486. async def handle_gc_sweep(self, request):
  487. mark = request["mark"]
  488. if not isinstance(mark, str):
  489. raise TypeError("Bad mark type %s" % type(mark))
  490. current_mark = await self.db.get_current_gc_mark()
  491. if not current_mark or mark != current_mark:
  492. raise bb.asyncrpc.InvokeError(
  493. f"'{mark}' is not the current mark. Refusing to sweep"
  494. )
  495. count = await self.db.gc_sweep()
  496. return {"count": count}
  497. @permissions(DB_ADMIN_PERM)
  498. async def handle_gc_status(self, request):
  499. (keep_rows, remove_rows, current_mark) = await self.db.gc_status()
  500. return {
  501. "keep": keep_rows,
  502. "remove": remove_rows,
  503. "mark": current_mark,
  504. }
  505. @permissions(DB_ADMIN_PERM)
  506. async def handle_clean_unused(self, request):
  507. max_age = request["max_age_seconds"]
  508. oldest = datetime.now() - timedelta(seconds=-max_age)
  509. return {"count": await self.db.clean_unused(oldest)}
  510. @permissions(DB_ADMIN_PERM)
  511. async def handle_get_db_usage(self, request):
  512. return {"usage": await self.db.get_usage()}
  513. @permissions(DB_ADMIN_PERM)
  514. async def handle_get_db_query_columns(self, request):
  515. return {"columns": await self.db.get_query_columns()}
  516. # The authentication API is always allowed
  517. async def handle_auth(self, request):
  518. username = str(request["username"])
  519. token = str(request["token"])
  520. async def fail_auth():
  521. nonlocal username
  522. # Rate limit bad login attempts
  523. await asyncio.sleep(1)
  524. raise bb.asyncrpc.InvokeError(f"Unable to authenticate as {username}")
  525. user, db_token = await self.db.lookup_user_token(username)
  526. if not user or not db_token:
  527. await fail_auth()
  528. try:
  529. algo, salt, _ = db_token.split(":")
  530. except ValueError:
  531. await fail_auth()
  532. if hash_token(algo, salt, token) != db_token:
  533. await fail_auth()
  534. self.user = user
  535. self.logger.info("Authenticated as %s", username)
  536. return {
  537. "result": True,
  538. "username": self.user.username,
  539. "permissions": sorted(list(self.user.permissions)),
  540. }
  541. @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
  542. async def handle_refresh_token(self, request):
  543. username = str(request["username"])
  544. token = await new_token()
  545. updated = await self.db.set_user_token(
  546. username,
  547. hash_token(TOKEN_ALGORITHM, new_salt(), token),
  548. )
  549. if not updated:
  550. self.raise_no_user_error(username)
  551. return {"username": username, "token": token}
  552. def get_perm_arg(self, arg):
  553. if not isinstance(arg, list):
  554. raise bb.asyncrpc.InvokeError("Unexpected type for permissions")
  555. arg = set(arg)
  556. try:
  557. arg.remove(NONE_PERM)
  558. except KeyError:
  559. pass
  560. unknown_perms = arg - ALL_PERMISSIONS
  561. if unknown_perms:
  562. raise bb.asyncrpc.InvokeError(
  563. "Unknown permissions %s" % ", ".join(sorted(list(unknown_perms)))
  564. )
  565. return sorted(list(arg))
  566. def return_perms(self, permissions):
  567. if ALL_PERM in permissions:
  568. return sorted(list(ALL_PERMISSIONS))
  569. return sorted(list(permissions))
  570. @permissions(USER_ADMIN_PERM, allow_anon=False)
  571. async def handle_set_perms(self, request):
  572. username = str(request["username"])
  573. permissions = self.get_perm_arg(request["permissions"])
  574. if not await self.db.set_user_perms(username, permissions):
  575. self.raise_no_user_error(username)
  576. return {
  577. "username": username,
  578. "permissions": self.return_perms(permissions),
  579. }
  580. @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
  581. async def handle_get_user(self, request):
  582. username = str(request["username"])
  583. user = await self.db.lookup_user(username)
  584. if user is None:
  585. return None
  586. return {
  587. "username": user.username,
  588. "permissions": self.return_perms(user.permissions),
  589. }
  590. @permissions(USER_ADMIN_PERM, allow_anon=False)
  591. async def handle_get_all_users(self, request):
  592. users = await self.db.get_all_users()
  593. return {
  594. "users": [
  595. {
  596. "username": u.username,
  597. "permissions": self.return_perms(u.permissions),
  598. }
  599. for u in users
  600. ]
  601. }
  602. @permissions(USER_ADMIN_PERM, allow_anon=False)
  603. async def handle_new_user(self, request):
  604. username = str(request["username"])
  605. permissions = self.get_perm_arg(request["permissions"])
  606. token = await new_token()
  607. inserted = await self.db.new_user(
  608. username,
  609. permissions,
  610. hash_token(TOKEN_ALGORITHM, new_salt(), token),
  611. )
  612. if not inserted:
  613. raise bb.asyncrpc.InvokeError(f"Cannot create new user '{username}'")
  614. return {
  615. "username": username,
  616. "permissions": self.return_perms(permissions),
  617. "token": token,
  618. }
  619. @permissions(USER_ADMIN_PERM, allow_self_service=True, allow_anon=False)
  620. async def handle_delete_user(self, request):
  621. username = str(request["username"])
  622. if not await self.db.delete_user(username):
  623. self.raise_no_user_error(username)
  624. return {"username": username}
  625. @permissions(USER_ADMIN_PERM, allow_anon=False)
  626. async def handle_become_user(self, request):
  627. username = str(request["username"])
  628. user = await self.db.lookup_user(username)
  629. if user is None:
  630. raise bb.asyncrpc.InvokeError(f"User {username} doesn't exist")
  631. self.user = user
  632. self.logger.info("Became user %s", username)
  633. return {
  634. "username": self.user.username,
  635. "permissions": self.return_perms(self.user.permissions),
  636. }
  637. class Server(bb.asyncrpc.AsyncServer):
  638. def __init__(
  639. self,
  640. db_engine,
  641. upstream=None,
  642. read_only=False,
  643. anon_perms=DEFAULT_ANON_PERMS,
  644. admin_username=None,
  645. admin_password=None,
  646. ):
  647. if upstream and read_only:
  648. raise bb.asyncrpc.ServerError(
  649. "Read-only hashserv cannot pull from an upstream server"
  650. )
  651. disallowed_perms = set(anon_perms) - set(
  652. [NONE_PERM, READ_PERM, REPORT_PERM, DB_ADMIN_PERM]
  653. )
  654. if disallowed_perms:
  655. raise bb.asyncrpc.ServerError(
  656. f"Permission(s) {' '.join(disallowed_perms)} are not allowed for anonymous users"
  657. )
  658. super().__init__(logger)
  659. self.request_stats = Stats()
  660. self.db_engine = db_engine
  661. self.upstream = upstream
  662. self.read_only = read_only
  663. self.backfill_queue = None
  664. self.anon_perms = set(anon_perms)
  665. self.admin_username = admin_username
  666. self.admin_password = admin_password
  667. self.logger.info(
  668. "Anonymous user permissions are: %s", ", ".join(self.anon_perms)
  669. )
  670. def accept_client(self, socket):
  671. return ServerClient(socket, self)
  672. async def create_admin_user(self):
  673. admin_permissions = (ALL_PERM,)
  674. async with self.db_engine.connect(self.logger) as db:
  675. added = await db.new_user(
  676. self.admin_username,
  677. admin_permissions,
  678. hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
  679. )
  680. if added:
  681. self.logger.info("Created admin user '%s'", self.admin_username)
  682. else:
  683. await db.set_user_perms(
  684. self.admin_username,
  685. admin_permissions,
  686. )
  687. await db.set_user_token(
  688. self.admin_username,
  689. hash_token(TOKEN_ALGORITHM, new_salt(), self.admin_password),
  690. )
  691. self.logger.info("Admin user '%s' updated", self.admin_username)
  692. async def backfill_worker_task(self):
  693. async with await create_async_client(
  694. self.upstream
  695. ) as client, self.db_engine.connect(self.logger) as db:
  696. while True:
  697. item = await self.backfill_queue.get()
  698. if item is None:
  699. self.backfill_queue.task_done()
  700. break
  701. method, taskhash = item
  702. d = await client.get_taskhash(method, taskhash)
  703. if d is not None:
  704. await db.insert_unihash(d["method"], d["taskhash"], d["unihash"])
  705. self.backfill_queue.task_done()
  706. def start(self):
  707. tasks = super().start()
  708. if self.upstream:
  709. self.backfill_queue = asyncio.Queue()
  710. tasks += [self.backfill_worker_task()]
  711. self.loop.run_until_complete(self.db_engine.create())
  712. if self.admin_username:
  713. self.loop.run_until_complete(self.create_admin_user())
  714. return tasks
  715. async def stop(self):
  716. if self.backfill_queue is not None:
  717. await self.backfill_queue.put(None)
  718. await super().stop()