sqlite_s3_query.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. from contextlib import contextmanager
  2. from ctypes import CFUNCTYPE, POINTER, Structure, create_string_buffer, pointer, cast, memmove, memset, sizeof, addressof, cdll, byref, string_at, c_char_p, c_int, c_double, c_int64, c_void_p, c_char
  3. from ctypes.util import find_library
  4. from functools import partial
  5. from hashlib import sha256
  6. import hmac
  7. from datetime import datetime
  8. import os
  9. from re import sub
  10. from time import time
  11. from urllib.parse import urlencode, urlsplit, quote
  12. from uuid import uuid4
  13. import httpx
  # --- sqlite_s3_query_multi: query a versioned SQLite file on S3 through a custom SQLite VFS ---
@contextmanager
def sqlite_s3_query_multi(url, get_credentials=lambda now: (
    os.environ['AWS_REGION'],
    os.environ['AWS_ACCESS_KEY_ID'],
    os.environ['AWS_SECRET_ACCESS_KEY'],
    os.environ.get('AWS_SESSION_TOKEN'),  # Only needed for temporary credentials
), get_http_client=lambda: httpx.Client(),
        get_libsqlite3=lambda: cdll.LoadLibrary(find_library('sqlite3'))):
    """Open a read-only SQLite database stored as a versioned S3 object.

    A custom SQLite VFS is registered whose xRead issues AWS Signature
    Version 4 signed HTTP range requests against ``url``, always for the
    object version reported by an initial HEAD request.

    Args:
        url: URL of the object holding the SQLite database file.
        get_credentials: called with the request time (a naive datetime in
            UTC); returns ``(region, access_key_id, secret_access_key,
            session_token)`` where ``session_token`` may be None. Defaults to
            the AWS_* environment variables.
        get_http_client: returns a context-manager HTTP client exposing
            ``stream(method, url, params=..., headers=...)``. Defaults to
            ``httpx.Client()``.
        get_libsqlite3: returns the ctypes handle to the SQLite C library.

    Yields:
        ``query(sql, params=())``: a generator yielding one
        ``(columns, rows)`` pair per statement in ``sql``, where ``rows`` is
        a generator of row tuples. ``params`` are bound to every statement.
        Consume each ``rows`` before advancing to the next pair: advancing
        finalizes the previous statement.

    Raises:
        Exception: if the object has no ``x-amz-version-id`` header (bucket
            versioning disabled), or when a SQLite call fails.
        KeyError: when a parameter's type has no binding.
    """
    # Local import: the module header only imports the `datetime` class
    from datetime import timezone

    libsqlite3 = get_libsqlite3()
    libsqlite3.sqlite3_errstr.restype = c_char_p
    libsqlite3.sqlite3_errmsg.restype = c_char_p
    libsqlite3.sqlite3_column_name.restype = c_char_p
    libsqlite3.sqlite3_column_double.restype = c_double
    libsqlite3.sqlite3_column_int64.restype = c_int64
    libsqlite3.sqlite3_column_blob.restype = c_void_p
    # sqlite3_column_bytes returns a C int, not a 64-bit integer
    libsqlite3.sqlite3_column_bytes.restype = c_int
    # Without argtypes, ctypes passes Python ints as 32-bit C ints (values outside
    # that range raise) and refuses to convert Python floats at all
    libsqlite3.sqlite3_bind_int64.argtypes = (c_void_p, c_int, c_int64)
    libsqlite3.sqlite3_bind_double.argtypes = (c_void_p, c_int, c_double)

    SQLITE_OK = 0
    SQLITE_IOERR = 10
    SQLITE_NOTFOUND = 12
    SQLITE_CANTOPEN = 14
    SQLITE_ROW = 100
    SQLITE_DONE = 101
    SQLITE_TRANSIENT = -1
    SQLITE_OPEN_READONLY = 0x00000001
    SQLITE_IOCAP_IMMUTABLE = 0x00002000

    def bind_text(pp_stmt, i, value):
        # Encode once so the pointer and the byte length always agree
        encoded = value.encode('utf-8')
        return libsqlite3.sqlite3_bind_text(pp_stmt, i, encoded, len(encoded), SQLITE_TRANSIENT)

    # Looked up by the exact type of each parameter; unsupported types raise KeyError
    bind = {
        int: libsqlite3.sqlite3_bind_int64,
        bool: libsqlite3.sqlite3_bind_int64,  # bool is an int subclass: bind as 0/1
        float: libsqlite3.sqlite3_bind_double,
        str: bind_text,
        bytes: lambda pp_stmt, i, value: libsqlite3.sqlite3_bind_blob(
            pp_stmt, i, value, len(value), SQLITE_TRANSIENT),
        type(None): lambda pp_stmt, i, _: libsqlite3.sqlite3_bind_null(pp_stmt, i),
    }

    def column_bytes(pp_stmt, i):
        # The data pointer is fetched before the length, the order SQLite documents
        return string_at(
            libsqlite3.sqlite3_column_blob(pp_stmt, i),
            libsqlite3.sqlite3_column_bytes(pp_stmt, i),
        )

    # Looked up by the result of sqlite3_column_type
    extract = {
        1: libsqlite3.sqlite3_column_int64,                       # SQLITE_INTEGER
        2: libsqlite3.sqlite3_column_double,                      # SQLITE_FLOAT
        3: lambda pp_stmt, i: column_bytes(pp_stmt, i).decode(),  # SQLITE_TEXT
        4: column_bytes,                                          # SQLITE_BLOB
        5: lambda pp_stmt, i: None,                               # SQLITE_NULL
    }

    # Unique per call: sqlite3_open_v2 below looks the VFS up by this name
    vfs_name = b's3-' + str(uuid4()).encode()
    file_name = b's3-' + str(uuid4()).encode()
    # Every request sent has an empty body, so its SHA-256 is a constant
    body_hash = sha256(b'').hexdigest()
    scheme, netloc, path, _, _ = urlsplit(url)

    def run(func, *args):
        # Call a SQLite function; on a non-zero result raise with SQLite's text for that code
        res = func(*args)
        if res != 0:
            raise Exception(libsqlite3.sqlite3_errstr(res).decode())

    def run_with_db(db, func, *args):
        # As run(), but the message is the connection's most recent error message
        if func(*args) != 0:
            raise Exception(libsqlite3.sqlite3_errmsg(db).decode())

    @contextmanager
    def make_auth_request(http_client, method, params, headers):
        # Stream a SigV4-signed request to the object URL; raise_for_status() runs before yielding
        # Naive datetime in UTC, exactly what the deprecated datetime.utcnow() returned
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        region, access_key_id, secret_access_key, session_token = get_credentials(now)
        to_auth_headers = headers + (
            (('x-amz-security-token', session_token),) if session_token is not None else
            ()
        )
        request_headers = aws_sigv4_headers(
            now, access_key_id, secret_access_key, region, method, to_auth_headers, params,
        )
        url = f'{scheme}://{netloc}{path}'
        with http_client.stream(method, url, params=params, headers=request_headers) as response:
            response.raise_for_status()
            yield response

    def aws_sigv4_headers(
        now, access_key_id, secret_access_key, region, method, headers_to_sign, params,
    ):
        # Return the request headers, including `authorization`, signed with AWS SigV4 for S3
        def sign(key, msg):
            return hmac.new(key, msg.encode('ascii'), sha256).digest()

        algorithm = 'AWS4-HMAC-SHA256'
        amzdate = now.strftime('%Y%m%dT%H%M%SZ')
        datestamp = amzdate[:8]
        credential_scope = f'{datestamp}/{region}/s3/aws4_request'

        headers = tuple(sorted(headers_to_sign + (
            ('host', netloc),
            ('x-amz-content-sha256', body_hash),
            ('x-amz-date', amzdate),
        )))
        signed_headers = ';'.join(key for key, _ in headers)

        canonical_uri = quote(path, safe='/~')
        quoted_params = sorted(
            (quote(key, safe='~'), quote(value, safe='~'))
            for key, value in params
        )
        canonical_querystring = '&'.join(f'{key}={value}' for key, value in quoted_params)
        canonical_headers = ''.join(f'{key}:{value}\n' for key, value in headers)
        canonical_request = f'{method}\n{canonical_uri}\n{canonical_querystring}\n' + \
            f'{canonical_headers}\n{signed_headers}\n{body_hash}'

        string_to_sign = f'{algorithm}\n{amzdate}\n{credential_scope}\n' + \
            sha256(canonical_request.encode('ascii')).hexdigest()

        date_key = sign(('AWS4' + secret_access_key).encode('ascii'), datestamp)
        region_key = sign(date_key, region)
        service_key = sign(region_key, 's3')
        request_key = sign(service_key, 'aws4_request')
        signature = sign(request_key, string_to_sign).hex()

        return (
            ('authorization', (
                f'{algorithm} Credential={access_key_id}/{credential_scope}, '
                f'SignedHeaders={signed_headers}, Signature={signature}')
            ),
        ) + headers

    @contextmanager
    def get_vfs(http_client):
        # Register (and on exit unregister) a VFS that serves reads from the S3 object
        with make_auth_request(http_client, 'HEAD', (), ()) as response:
            head_headers = response.headers
            next(response.iter_bytes(), b'')

        try:
            version_id = head_headers['x-amz-version-id']
        except KeyError:
            raise Exception('The bucket must have versioning enabled')

        size = int(head_headers['content-length'])

        def make_struct(fields):
            # Build and instantiate a ctypes Structure from (name, type, value) triples
            class Struct(Structure):
                _fields_ = [(field_name, field_type) for (field_name, field_type, _) in fields]
            return Struct(*tuple(value for (_, _, value) in fields))

        x_open_type = CFUNCTYPE(c_int, c_void_p, c_char_p, c_void_p, c_int, POINTER(c_int))
        def x_open(p_vfs, z_name, p_file, flags, p_out_flags):
            # Every open gets a copy of the same file struct
            memmove(p_file, addressof(file), sizeof(file))
            # SQLite allows pOutFlags to be NULL
            if p_out_flags:
                p_out_flags[0] = flags
            return SQLITE_OK

        x_close_type = CFUNCTYPE(c_int, c_void_p)
        def x_close(p_file):
            return SQLITE_OK

        x_read_type = CFUNCTYPE(c_int, c_void_p, c_void_p, c_int, c_int64)
        def x_read(p_file, p_out, i_amt, i_ofst):
            # Fetch bytes [i_ofst, i_ofst + i_amt) of the object version seen by HEAD.
            # An exception escaping a ctypes callback would not reach Python callers,
            # so every failure is reported to SQLite as SQLITE_IOERR instead.
            offset = 0
            try:
                with make_auth_request(http_client, 'GET',
                    (('versionId', version_id),),
                    (('range', f'bytes={i_ofst}-{i_ofst + i_amt - 1}'),)
                ) as response:
                    # Handle the case of the server being broken or slightly evil,
                    # returning more than the number of bytes that's asked for
                    for chunk in response.iter_bytes():
                        memmove(p_out + offset, chunk, min(i_amt - offset, len(chunk)))
                        offset += len(chunk)
                        if offset > i_amt:
                            break
            except Exception:
                return SQLITE_IOERR

            if offset != i_amt:
                return SQLITE_IOERR

            return SQLITE_OK

        x_file_size_type = CFUNCTYPE(c_int, c_void_p, POINTER(c_int64))
        def x_file_size(p_file, p_size):
            p_size[0] = size
            return SQLITE_OK

        x_lock_type = CFUNCTYPE(c_int, c_void_p, c_int)
        def x_lock(p_file, e_lock):
            return SQLITE_OK

        x_unlock_type = CFUNCTYPE(c_int, c_void_p, c_int)
        def x_unlock(p_file, e_lock):
            return SQLITE_OK

        x_file_control_type = CFUNCTYPE(c_int, c_void_p, c_int, c_void_p)
        def x_file_control(p_file, op, p_arg):
            return SQLITE_NOTFOUND

        x_device_characteristics_type = CFUNCTYPE(c_int, c_void_p)
        def x_device_characteristics(p_file):
            return SQLITE_IOCAP_IMMUTABLE

        x_access_type = CFUNCTYPE(c_int, c_void_p, c_char_p, c_int, POINTER(c_int))
        def x_access(p_vfs, z_name, flags, z_out):
            z_out[0] = 0
            return SQLITE_OK

        x_full_pathname_type = CFUNCTYPE(c_int, c_void_p, c_char_p, c_int, POINTER(c_char))
        def x_full_pathname(p_vfs, z_name, n_out, z_out):
            # z_out holds n_out bytes; the name plus its NUL terminator must fit
            if len(z_name) + 1 > n_out:
                return SQLITE_CANTOPEN
            memmove(z_out, z_name, len(z_name) + 1)
            return SQLITE_OK

        x_current_time_type = CFUNCTYPE(c_int, c_void_p, POINTER(c_double))
        def x_current_time(p_vfs, c_double_p):
            # Current time as a Julian day number
            c_double_p[0] = time() / 86400.0 + 2440587.5
            return SQLITE_OK

        io_methods = make_struct((
            ('i_version', c_int, 1),
            ('x_close', x_close_type, x_close_type(x_close)),
            ('x_read', x_read_type, x_read_type(x_read)),
            ('x_write', c_void_p, None),
            ('x_truncate', c_void_p, None),
            ('x_sync', c_void_p, None),
            ('x_file_size', x_file_size_type, x_file_size_type(x_file_size)),
            ('x_lock', x_lock_type, x_lock_type(x_lock)),
            ('x_unlock', x_unlock_type, x_unlock_type(x_unlock)),
            ('x_check_reserved_lock', c_void_p, None),
            ('x_file_control', x_file_control_type, x_file_control_type(x_file_control)),
            ('x_sector_size', c_void_p, None),
            ('x_device_characteristics', x_device_characteristics_type, x_device_characteristics_type(x_device_characteristics)),
        ))
        file = make_struct((
            ('p_methods', POINTER(type(io_methods)), pointer(io_methods)),
        ))
        vfs = make_struct((
            ('i_version', c_int, 1),
            ('sz_os_file', c_int, sizeof(file)),
            ('mx_pathname', c_int, 1024),
            ('p_next', c_void_p, None),
            ('z_name', c_char_p, vfs_name),
            ('p_app_data', c_char_p, None),
            ('x_open', x_open_type, x_open_type(x_open)),
            ('x_delete', c_void_p, None),
            ('x_access', x_access_type, x_access_type(x_access)),
            ('x_full_pathname', x_full_pathname_type, x_full_pathname_type(x_full_pathname)),
            ('x_dl_open', c_void_p, None),
            ('x_dl_error', c_void_p, None),
            ('x_dl_sym', c_void_p, None),
            ('x_dl_close', c_void_p, None),
            ('x_randomness', c_void_p, None),
            ('x_sleep', c_void_p, None),
            ('x_current_time', x_current_time_type, x_current_time_type(x_current_time)),
            ('x_get_last_error', c_void_p, None),
        ))

        run(libsqlite3.sqlite3_vfs_register, byref(vfs), 0)
        try:
            yield vfs
        finally:
            run(libsqlite3.sqlite3_vfs_unregister, byref(vfs))

    @contextmanager
    def get_db(vfs):
        # Open (and on exit close) a read-only connection through the registered VFS;
        # `vfs` is taken only so the VFS context is entered first
        db = c_void_p()
        res = libsqlite3.sqlite3_open_v2(file_name, byref(db), SQLITE_OPEN_READONLY, vfs_name)
        if res != SQLITE_OK:
            # SQLite can allocate a connection even when opening fails, and it must
            # still be released; closing a NULL handle is a harmless no-op
            libsqlite3.sqlite3_close(db)
            raise Exception(libsqlite3.sqlite3_errstr(res).decode())
        try:
            yield db
        finally:
            run_with_db(db, libsqlite3.sqlite3_close, db)

    @contextmanager
    def get_pp_stmt_getter(db):
        # The purpose of this context manager is to make sure we finalize statements before
        # attempting to close the database, including in the case of unfinished iteration
        statements = {}

        def get_pp_stmt(statement):
            try:
                return statements[statement]
            except KeyError:
                raise Exception('Attempting to use finalized statement') from None

        def finalize(statement):
            # In case there are errors, don't attempt to re-finalize the same statement
            try:
                pp_stmt = statements.pop(statement)
            except KeyError:
                return

            try:
                run_with_db(db, libsqlite3.sqlite3_finalize, pp_stmt)
            except Exception:
                # The only case found where this errored is when we've already had an error due to
                # a malformed disk image, which will already bubble up to client code
                pass

        def get_pp_stmts(sql):
            # Prepare the statements of `sql` one at a time; pzTail advances p_encoded
            p_encoded = POINTER(c_char)(create_string_buffer(sql.encode()))

            while True:
                pp_stmt = c_void_p()
                run_with_db(db, libsqlite3.sqlite3_prepare_v2, db, p_encoded, -1, byref(pp_stmt), byref(p_encoded))
                if not pp_stmt:
                    break

                # c_void_p is not hashable, and there is a theoretical possibility that multiple
                # exist at the same time pointing to the same memory, so use a plain object instead
                statement = object()
                statements[statement] = pp_stmt
                yield partial(get_pp_stmt, statement), partial(finalize, statement)

        try:
            yield get_pp_stmts
        finally:
            for statement in statements.copy().keys():
                finalize(statement)

    def rows(get_pp_stmt, columns):
        # Step the statement, yielding each result row as a tuple of Python values
        while True:
            pp_stmt = get_pp_stmt()
            res = libsqlite3.sqlite3_step(pp_stmt)
            if res == SQLITE_DONE:
                break
            if res != SQLITE_ROW:
                raise Exception(libsqlite3.sqlite3_errstr(res).decode())

            yield tuple(
                extract[libsqlite3.sqlite3_column_type(pp_stmt, i)](pp_stmt, i)
                for i in range(0, len(columns))
            )

    def query(db, get_pp_stmts, sql, params=()):
        # Yield (columns, rows) per statement; each statement is finalized on resumption
        for get_pp_stmt, finalize_stmt in get_pp_stmts(sql):
            try:
                pp_stmt = get_pp_stmt()
                for i, param in enumerate(params):
                    run_with_db(db, bind[type(param)], pp_stmt, i + 1, param)

                columns = tuple(
                    libsqlite3.sqlite3_column_name(pp_stmt, i).decode()
                    for i in range(0, libsqlite3.sqlite3_column_count(pp_stmt))
                )

                yield columns, rows(get_pp_stmt, columns)
            finally:
                finalize_stmt()

    with \
            get_http_client() as http_client, \
            get_vfs(http_client) as vfs, \
            get_db(vfs) as db, \
            get_pp_stmt_getter(db) as get_pp_stmts:
        yield partial(query, db, get_pp_stmts)
  # --- sqlite_s3_query: single-statement convenience wrapper around sqlite_s3_query_multi ---
@contextmanager
def sqlite_s3_query(url, get_credentials=lambda now: (
    os.environ['AWS_REGION'],
    os.environ['AWS_ACCESS_KEY_ID'],
    os.environ['AWS_SECRET_ACCESS_KEY'],
    os.environ.get('AWS_SESSION_TOKEN'),  # Only needed for temporary credentials
), get_http_client=lambda: httpx.Client(),
        get_libsqlite3=lambda: cdll.LoadLibrary(find_library('sqlite3'))):
    """Like sqlite_s3_query_multi, but each query exposes only its first statement.

    All arguments are forwarded unchanged to sqlite_s3_query_multi. Yields a
    ``query(sql, params=())`` context manager whose value is the
    ``(columns, rows)`` pair of the first statement in ``sql``.
    """
    @contextmanager
    def first_statement(multi_query, sql, params=()):
        # Stop after the first pair: the underlying generator is never advanced
        # again, so later statements in `sql` are not prepared
        for columns_and_rows in multi_query(sql, params):
            yield columns_and_rows
            break

    forwarded = dict(
        get_credentials=get_credentials,
        get_http_client=get_http_client,
        get_libsqlite3=get_libsqlite3,
    )
    with sqlite_s3_query_multi(url, **forwarded) as multi_query:
        yield partial(first_statement, multi_query)