client.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. # Copyright (C) 2019 Garmin Ltd.
  2. #
  3. # SPDX-License-Identifier: GPL-2.0-only
  4. #
  5. import logging
  6. import socket
  7. import asyncio
  8. import bb.asyncrpc
  9. import json
  10. from . import create_async_client
  11. logger = logging.getLogger("hashserv.client")
  12. class Batch(object):
  13. def __init__(self):
  14. self.done = False
  15. self.cond = asyncio.Condition()
  16. self.pending = []
  17. self.results = []
  18. self.sent_count = 0
  19. async def recv(self, socket):
  20. while True:
  21. async with self.cond:
  22. await self.cond.wait_for(lambda: self.pending or self.done)
  23. if not self.pending:
  24. if self.done:
  25. return
  26. continue
  27. r = await socket.recv()
  28. self.results.append(r)
  29. async with self.cond:
  30. self.pending.pop(0)
  31. async def send(self, socket, msgs):
  32. try:
  33. # In the event of a restart due to a reconnect, all in-flight
  34. # messages need to be resent first to keep to result count in sync
  35. for m in self.pending:
  36. await socket.send(m)
  37. for m in msgs:
  38. # Add the message to the pending list before attempting to send
  39. # it so that if the send fails it will be retried
  40. async with self.cond:
  41. self.pending.append(m)
  42. self.cond.notify()
  43. self.sent_count += 1
  44. await socket.send(m)
  45. finally:
  46. async with self.cond:
  47. self.done = True
  48. self.cond.notify()
  49. async def process(self, socket, msgs):
  50. await asyncio.gather(
  51. self.recv(socket),
  52. self.send(socket, msgs),
  53. )
  54. if len(self.results) != self.sent_count:
  55. raise ValueError(
  56. f"Expected result count {len(self.results)}. Expected {self.sent_count}"
  57. )
  58. return self.results
  59. class AsyncClient(bb.asyncrpc.AsyncClient):
  60. MODE_NORMAL = 0
  61. MODE_GET_STREAM = 1
  62. MODE_EXIST_STREAM = 2
  63. MODE_MARK_STREAM = 3
  64. def __init__(self, username=None, password=None):
  65. super().__init__("OEHASHEQUIV", "1.1", logger)
  66. self.mode = self.MODE_NORMAL
  67. self.username = username
  68. self.password = password
  69. self.saved_become_user = None
  70. async def setup_connection(self):
  71. await super().setup_connection()
  72. self.mode = self.MODE_NORMAL
  73. if self.username:
  74. # Save off become user temporarily because auth() resets it
  75. become = self.saved_become_user
  76. await self.auth(self.username, self.password)
  77. if become:
  78. await self.become_user(become)
  79. async def send_stream_batch(self, mode, msgs):
  80. """
  81. Does a "batch" process of stream messages. This sends the query
  82. messages as fast as possible, and simultaneously attempts to read the
  83. messages back. This helps to mitigate the effects of latency to the
  84. hash equivalence server be allowing multiple queries to be "in-flight"
  85. at once
  86. The implementation does more complicated tracking using a count of sent
  87. messages so that `msgs` can be a generator function (i.e. its length is
  88. unknown)
  89. """
  90. b = Batch()
  91. async def proc():
  92. nonlocal b
  93. await self._set_mode(mode)
  94. return await b.process(self.socket, msgs)
  95. return await self._send_wrapper(proc)
  96. async def invoke(self, *args, skip_mode=False, **kwargs):
  97. # It's OK if connection errors cause a failure here, because the mode
  98. # is also reset to normal on a new connection
  99. if not skip_mode:
  100. await self._set_mode(self.MODE_NORMAL)
  101. return await super().invoke(*args, **kwargs)
  102. async def _set_mode(self, new_mode):
  103. async def stream_to_normal():
  104. # Check if already in normal mode (e.g. due to a connection reset)
  105. if self.mode == self.MODE_NORMAL:
  106. return "ok"
  107. await self.socket.send("END")
  108. return await self.socket.recv()
  109. async def normal_to_stream(command):
  110. r = await self.invoke({command: None}, skip_mode=True)
  111. if r != "ok":
  112. self.check_invoke_error(r)
  113. raise ConnectionError(
  114. f"Unable to transition to stream mode: Bad response from server {r!r}"
  115. )
  116. self.logger.debug("Mode is now %s", command)
  117. if new_mode == self.mode:
  118. return
  119. self.logger.debug("Transitioning mode %s -> %s", self.mode, new_mode)
  120. # Always transition to normal mode before switching to any other mode
  121. if self.mode != self.MODE_NORMAL:
  122. r = await self._send_wrapper(stream_to_normal)
  123. if r != "ok":
  124. self.check_invoke_error(r)
  125. raise ConnectionError(
  126. f"Unable to transition to normal mode: Bad response from server {r!r}"
  127. )
  128. self.logger.debug("Mode is now normal")
  129. if new_mode == self.MODE_GET_STREAM:
  130. await normal_to_stream("get-stream")
  131. elif new_mode == self.MODE_EXIST_STREAM:
  132. await normal_to_stream("exists-stream")
  133. elif new_mode == self.MODE_MARK_STREAM:
  134. await normal_to_stream("gc-mark-stream")
  135. elif new_mode != self.MODE_NORMAL:
  136. raise Exception("Undefined mode transition {self.mode!r} -> {new_mode!r}")
  137. self.mode = new_mode
  138. async def get_unihash(self, method, taskhash):
  139. r = await self.get_unihash_batch([(method, taskhash)])
  140. return r[0]
  141. async def get_unihash_batch(self, args):
  142. result = await self.send_stream_batch(
  143. self.MODE_GET_STREAM,
  144. (f"{method} {taskhash}" for method, taskhash in args),
  145. )
  146. return [r if r else None for r in result]
  147. async def report_unihash(self, taskhash, method, outhash, unihash, extra={}):
  148. m = extra.copy()
  149. m["taskhash"] = taskhash
  150. m["method"] = method
  151. m["outhash"] = outhash
  152. m["unihash"] = unihash
  153. return await self.invoke({"report": m})
  154. async def report_unihash_equiv(self, taskhash, method, unihash, extra={}):
  155. m = extra.copy()
  156. m["taskhash"] = taskhash
  157. m["method"] = method
  158. m["unihash"] = unihash
  159. return await self.invoke({"report-equiv": m})
  160. async def get_taskhash(self, method, taskhash, all_properties=False):
  161. return await self.invoke(
  162. {"get": {"taskhash": taskhash, "method": method, "all": all_properties}}
  163. )
  164. async def unihash_exists(self, unihash):
  165. r = await self.unihash_exists_batch([unihash])
  166. return r[0]
  167. async def unihash_exists_batch(self, unihashes):
  168. result = await self.send_stream_batch(self.MODE_EXIST_STREAM, unihashes)
  169. return [r == "true" for r in result]
  170. async def get_outhash(self, method, outhash, taskhash, with_unihash=True):
  171. return await self.invoke(
  172. {
  173. "get-outhash": {
  174. "outhash": outhash,
  175. "taskhash": taskhash,
  176. "method": method,
  177. "with_unihash": with_unihash,
  178. }
  179. }
  180. )
  181. async def get_stats(self):
  182. return await self.invoke({"get-stats": None})
  183. async def reset_stats(self):
  184. return await self.invoke({"reset-stats": None})
  185. async def backfill_wait(self):
  186. return (await self.invoke({"backfill-wait": None}))["tasks"]
  187. async def remove(self, where):
  188. return await self.invoke({"remove": {"where": where}})
  189. async def clean_unused(self, max_age):
  190. return await self.invoke({"clean-unused": {"max_age_seconds": max_age}})
  191. async def auth(self, username, token):
  192. result = await self.invoke({"auth": {"username": username, "token": token}})
  193. self.username = username
  194. self.password = token
  195. self.saved_become_user = None
  196. return result
  197. async def refresh_token(self, username=None):
  198. m = {}
  199. if username:
  200. m["username"] = username
  201. result = await self.invoke({"refresh-token": m})
  202. if (
  203. self.username
  204. and not self.saved_become_user
  205. and result["username"] == self.username
  206. ):
  207. self.password = result["token"]
  208. return result
  209. async def set_user_perms(self, username, permissions):
  210. return await self.invoke(
  211. {"set-user-perms": {"username": username, "permissions": permissions}}
  212. )
  213. async def get_user(self, username=None):
  214. m = {}
  215. if username:
  216. m["username"] = username
  217. return await self.invoke({"get-user": m})
  218. async def get_all_users(self):
  219. return (await self.invoke({"get-all-users": {}}))["users"]
  220. async def new_user(self, username, permissions):
  221. return await self.invoke(
  222. {"new-user": {"username": username, "permissions": permissions}}
  223. )
  224. async def delete_user(self, username):
  225. return await self.invoke({"delete-user": {"username": username}})
  226. async def become_user(self, username):
  227. result = await self.invoke({"become-user": {"username": username}})
  228. if username == self.username:
  229. self.saved_become_user = None
  230. else:
  231. self.saved_become_user = username
  232. return result
  233. async def get_db_usage(self):
  234. return (await self.invoke({"get-db-usage": {}}))["usage"]
  235. async def get_db_query_columns(self):
  236. return (await self.invoke({"get-db-query-columns": {}}))["columns"]
  237. async def gc_status(self):
  238. return await self.invoke({"gc-status": {}})
  239. async def gc_mark(self, mark, where):
  240. """
  241. Starts a new garbage collection operation identified by "mark". If
  242. garbage collection is already in progress with "mark", the collection
  243. is continued.
  244. All unihash entries that match the "where" clause are marked to be
  245. kept. In addition, any new entries added to the database after this
  246. command will be automatically marked with "mark"
  247. """
  248. return await self.invoke({"gc-mark": {"mark": mark, "where": where}})
  249. async def gc_mark_stream(self, mark, rows):
  250. """
  251. Similar to `gc-mark`, but accepts a list of "where" key-value pair
  252. conditions. It utilizes stream mode to mark hashes, which helps reduce
  253. the impact of latency when communicating with the hash equivalence
  254. server.
  255. """
  256. def row_to_dict(row):
  257. pairs = row.split()
  258. return dict(zip(pairs[::2], pairs[1::2]))
  259. responses = await self.send_stream_batch(
  260. self.MODE_MARK_STREAM,
  261. (json.dumps({"mark": mark, "where": row_to_dict(row)}) for row in rows),
  262. )
  263. return {"count": sum(int(json.loads(r)["count"]) for r in responses)}
  264. async def gc_sweep(self, mark):
  265. """
  266. Finishes garbage collection for "mark". All unihash entries that have
  267. not been marked will be deleted.
  268. It is recommended to clean unused outhash entries after running this to
  269. cleanup any dangling outhashes
  270. """
  271. return await self.invoke({"gc-sweep": {"mark": mark}})
  272. class Client(bb.asyncrpc.Client):
  273. def __init__(self, username=None, password=None):
  274. self.username = username
  275. self.password = password
  276. super().__init__()
  277. self._add_methods(
  278. "connect_tcp",
  279. "connect_websocket",
  280. "get_unihash",
  281. "get_unihash_batch",
  282. "report_unihash",
  283. "report_unihash_equiv",
  284. "get_taskhash",
  285. "unihash_exists",
  286. "unihash_exists_batch",
  287. "get_outhash",
  288. "get_stats",
  289. "reset_stats",
  290. "backfill_wait",
  291. "remove",
  292. "clean_unused",
  293. "auth",
  294. "refresh_token",
  295. "set_user_perms",
  296. "get_user",
  297. "get_all_users",
  298. "new_user",
  299. "delete_user",
  300. "become_user",
  301. "get_db_usage",
  302. "get_db_query_columns",
  303. "gc_status",
  304. "gc_mark",
  305. "gc_mark_stream",
  306. "gc_sweep",
  307. )
  308. def _get_async_client(self):
  309. return AsyncClient(self.username, self.password)