# app.py

from .consumer import BatchConsumer
from .consumer import Subscription
from .exceptions import AppNotConfiguredException
from .exceptions import ProducerUnhealthyException
from .exceptions import SchemaConflictException
from .exceptions import StopConsumer
from .kafka import KafkaTopicManager
from .metrics import NOERROR
from .metrics import PRODUCER_TOPIC_OFFSET
from .metrics import PUBLISHED_MESSAGES
from .metrics import PUBLISHED_MESSAGES_TIME
from .metrics import watch_kafka
from .metrics import watch_publish
from .utils import resolve_dotted_name
from asyncio.futures import Future
from functools import partial
from opentracing.scope_managers.contextvars import ContextVarsScopeManager
from pydantic import BaseModel
from types import TracebackType
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type

import aiokafka
import aiokafka.errors
import aiokafka.structs
import argparse
import asyncio
import logging
import opentracing
import orjson
import pydantic
import signal
import time

logger = logging.getLogger("kafkaesk")


class SchemaRegistration:
    def __init__(
        self,
        id: str,
        version: int,
        model: Type[pydantic.BaseModel],
        retention: Optional[int] = None,
        streams: Optional[List[str]] = None,
    ):
        self.id = id
        self.version = version
        self.model = model
        self.retention = retention
        self.streams = streams

    def __repr__(self) -> str:
        return f"<SchemaRegistration {self.id}, version: {self.version} >"


def published_callback(topic: str, start_time: float, fut: Future) -> None:
    # Record the metrics
    finish_time = time.time()
    exception = fut.exception()
    if exception:
        error = str(exception.__class__.__name__)
        PUBLISHED_MESSAGES.labels(stream_id=topic, partition=-1, error=error).inc()
    else:
        metadata = fut.result()
        PUBLISHED_MESSAGES.labels(
            stream_id=topic, partition=metadata.partition, error=NOERROR
        ).inc()
        PRODUCER_TOPIC_OFFSET.labels(stream_id=topic, partition=metadata.partition).set(
            metadata.offset
        )
        PUBLISHED_MESSAGES_TIME.labels(stream_id=topic).observe(finish_time - start_time)
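
# Illustrative sketch (not part of the module's runtime path): `published_callback`
# is designed to be attached, via `functools.partial`, to the future returned by
# `AIOKafkaProducer.send()`, exactly as `Application.raw_publish()` does below:
#
#     fut = await producer.send(topic_id, value=data)
#     fut.add_done_callback(partial(published_callback, topic_id, time.time()))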

_aiokafka_consumer_settings = (
    "fetch_max_wait_ms",
    "fetch_max_bytes",
    "fetch_min_bytes",
    "max_partition_fetch_bytes",
    "request_timeout_ms",
    "auto_offset_reset",
    "metadata_max_age_ms",
    "max_poll_interval_ms",
    "rebalance_timeout_ms",
    "session_timeout_ms",
    "heartbeat_interval_ms",
    "consumer_timeout_ms",
    "max_poll_records",
    "connections_max_idle_ms",
    "ssl_context",
    "security_protocol",
    "sasl_mechanism",
    "sasl_plain_username",
    "sasl_plain_password",
)

_aiokafka_producer_settings = (
    "metadata_max_age_ms",
    "request_timeout_ms",
    "max_batch_size",
    "max_request_size",
    "send_backoff_ms",
    "retry_backoff_ms",
    "ssl_context",
    "security_protocol",
    "sasl_mechanism",
    "sasl_plain_username",
    "sasl_plain_password",
)
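
# The two tuples above whitelist which `kafka_settings` keys are forwarded to
# aiokafka by `consumer_factory()` / `producer_factory()`; keys outside them are
# silently ignored. A hypothetical settings dict might look like:
#
#     Application(
#         kafka_servers=["localhost:9092"],
#         kafka_settings={
#             "max_poll_records": 100,           # consumer-only
#             "max_request_size": 1048576,       # producer-only
#             "security_protocol": "PLAINTEXT",  # accepted by both
#         },
#     )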


class Router:
    """
    Application routing configuration.
    """

    def __init__(self) -> None:
        self._subscriptions: List[Subscription] = []
        self._schemas: Dict[str, SchemaRegistration] = {}
        self._event_handlers: Dict[str, List[Callable[[], Awaitable[None]]]] = {}

    @property
    def subscriptions(self) -> List[Subscription]:
        return self._subscriptions

    @property
    def schemas(self) -> Dict[str, SchemaRegistration]:
        return self._schemas

    @property
    def event_handlers(self) -> Dict[str, List[Callable[[], Awaitable[None]]]]:
        return self._event_handlers

    def on(self, name: str, handler: Callable[[], Awaitable[None]]) -> None:
        if name not in self._event_handlers:
            self._event_handlers[name] = []
        self._event_handlers[name].append(handler)

    def _subscribe(
        self,
        group: str,
        *,
        consumer_id: Optional[str] = None,
        pattern: Optional[str] = None,
        topics: Optional[List[str]] = None,
        timeout_seconds: Optional[float] = None,
        concurrency: Optional[int] = None,
    ) -> Callable:
        def inner(func: Callable) -> Callable:
            # If there is no consumer_id, use the group instead
            subscription = Subscription(
                consumer_id or group,
                func,
                group or func.__name__,
                pattern=pattern,
                topics=topics,
                concurrency=concurrency,
                timeout_seconds=timeout_seconds,
            )
            self._subscriptions.append(subscription)
            return func

        return inner

    def subscribe_to_topics(
        self,
        topics: List[str],
        group: str,
        *,
        timeout_seconds: Optional[float] = None,
        concurrency: Optional[int] = None,
    ) -> Callable:
        return self._subscribe(
            group=group,
            topics=topics,
            pattern=None,
            timeout_seconds=timeout_seconds,
            concurrency=concurrency,
        )

    def subscribe_to_pattern(
        self,
        pattern: str,
        group: str,
        *,
        timeout_seconds: Optional[float] = None,
        concurrency: Optional[int] = None,
    ) -> Callable:
        return self._subscribe(
            group=group,
            topics=None,
            pattern=pattern,
            timeout_seconds=timeout_seconds,
            concurrency=concurrency,
        )

    def subscribe(
        self,
        stream_id: str,
        group: str,
        *,
        timeout_seconds: Optional[float] = None,
        concurrency: Optional[int] = None,
    ) -> Callable:
        """Kept for backwards compatibility: treats `stream_id` as a pattern."""
        return self._subscribe(
            group=group,
            topics=None,
            pattern=stream_id,
            timeout_seconds=timeout_seconds,
            concurrency=concurrency,
        )

    def schema(
        self,
        _id: Optional[str] = None,
        *,
        version: Optional[int] = None,
        retention: Optional[int] = None,
        streams: Optional[List[str]] = None,
    ) -> Callable:
        version = version or 1

        def inner(cls: Type[BaseModel]) -> Type[BaseModel]:
            if _id is None:
                type_id = cls.__name__
            else:
                type_id = _id
            key = f"{type_id}:{version}"
            reg = SchemaRegistration(
                id=type_id, version=version or 1, model=cls, retention=retention, streams=streams
            )
            if key in self._schemas:
                raise SchemaConflictException(self._schemas[key], reg)
            cls.__key__ = key  # type: ignore
            self._schemas[key] = reg
            return cls

        return inner
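
# Usage sketch (stream, group, and model names are illustrative): schemas and
# subscriptions registered on a standalone Router can be mounted onto an
# Application with `Application.mount()`:
#
#     router = Router()
#
#     @router.schema("Content", version=1, retention=24 * 60 * 60)
#     class ContentMessage(pydantic.BaseModel):
#         foo: str
#
#     @router.subscribe_to_pattern("content.*", group="content_group")
#     async def handle_content(data: ContentMessage) -> None:
#         ...
#
#     app = Application(kafka_servers=["localhost:9092"])
#     app.mount(router)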


class Application(Router):
    """
    Application configuration.
    """

    _producer: Optional[aiokafka.AIOKafkaProducer]

    def __init__(
        self,
        kafka_servers: Optional[List[str]] = None,
        topic_prefix: str = "",
        kafka_settings: Optional[Dict[str, Any]] = None,
        replication_factor: Optional[int] = None,
        kafka_api_version: str = "auto",
        auto_commit: bool = True,
    ):
        super().__init__()
        self._kafka_servers = kafka_servers
        self._kafka_settings = kafka_settings
        self._producer = None
        self._initialized = False
        self._locks: Dict[str, asyncio.Lock] = {}
        self._kafka_api_version = kafka_api_version
        self._topic_prefix = topic_prefix
        self._replication_factor = replication_factor
        self._topic_mng: Optional[KafkaTopicManager] = None
        self._subscription_consumers: List[BatchConsumer] = []
        self._subscription_consumers_tasks: List[asyncio.Task] = []
        self.auto_commit = auto_commit

    @property
    def kafka_settings(self) -> Dict[str, Any]:
        return self._kafka_settings or {}

    def mount(self, router: Router) -> None:
        self._subscriptions.extend(router.subscriptions)
        self._schemas.update(router.schemas)
        self._event_handlers.update(router.event_handlers)

    async def health_check(self) -> None:
        for subscription_consumer in self._subscription_consumers:
            await subscription_consumer.healthy()
        if not self.producer_healthy():
            raise ProducerUnhealthyException(self._producer)  # type: ignore

    async def _call_event_handlers(self, name: str) -> None:
        handlers = self._event_handlers.get(name)
        if handlers is not None:
            for handler in handlers:
                await handler()

    @property
    def topic_mng(self) -> KafkaTopicManager:
        if self._topic_mng is None:
            self._topic_mng = KafkaTopicManager(
                cast(List[str], self._kafka_servers),
                self._topic_prefix,
                replication_factor=self._replication_factor,
                kafka_api_version=self._kafka_api_version,
                ssl_context=self.kafka_settings.get("ssl_context"),
                security_protocol=self.kafka_settings.get("security_protocol", "PLAINTEXT"),
                sasl_mechanism=self.kafka_settings.get("sasl_mechanism"),
                sasl_plain_username=self.kafka_settings.get("sasl_plain_username"),
                sasl_plain_password=self.kafka_settings.get("sasl_plain_password"),
            )
        return self._topic_mng

    def get_lock(self, name: str) -> asyncio.Lock:
        if name not in self._locks:
            self._locks[name] = asyncio.Lock()
        return self._locks[name]

    def configure(
        self,
        kafka_servers: Optional[List[str]] = None,
        topic_prefix: Optional[str] = None,
        kafka_settings: Optional[Dict[str, Any]] = None,
        api_version: Optional[str] = None,
        replication_factor: Optional[int] = None,
    ) -> None:
        if kafka_servers is not None:
            self._kafka_servers = kafka_servers
        if topic_prefix is not None:
            self._topic_prefix = topic_prefix
        if kafka_settings is not None:
            self._kafka_settings = kafka_settings
        if api_version is not None:
            self._kafka_api_version = api_version
        if replication_factor is not None:
            self._replication_factor = replication_factor

    @property
    def is_configured(self) -> bool:
        return bool(self._kafka_servers)

    async def publish_and_wait(
        self,
        stream_id: str,
        data: BaseModel,
        key: Optional[bytes] = None,
        headers: Optional[List[Tuple[str, bytes]]] = None,
    ) -> aiokafka.structs.ConsumerRecord:
        # `publish` returns the send future; awaiting it waits for the broker ack
        return await (await self.publish(stream_id, data, key, headers=headers))

    async def _maybe_create_topic(self, stream_id: str, data: Optional[BaseModel] = None) -> None:
        topic_id = self.topic_mng.get_topic_id(stream_id)
        async with self.get_lock(stream_id):
            if not await self.topic_mng.topic_exists(topic_id):
                reg = None
                if data:
                    reg = self.get_schema_reg(data)
                retention_ms = None
                if reg is not None and reg.retention is not None:
                    retention_ms = reg.retention * 1000
                await self.topic_mng.create_topic(
                    topic_id,
                    replication_factor=self._replication_factor,
                    retention_ms=retention_ms,
                )

    async def publish(
        self,
        stream_id: str,
        data: BaseModel,
        key: Optional[bytes] = None,
        headers: Optional[List[Tuple[str, bytes]]] = None,
    ) -> Awaitable[aiokafka.structs.ConsumerRecord]:
        if not self._initialized:
            async with self.get_lock("_"):
                await self.initialize()

        schema_key = getattr(data, "__key__", None)
        if schema_key not in self._schemas:
            # do not require a registered key; fall back to the class name
            schema_key = f"{data.__class__.__name__}:1"
        data_ = data.dict()

        await self._maybe_create_topic(stream_id, data)
        return await self.raw_publish(
            stream_id, orjson.dumps({"schema": schema_key, "data": data_}), key, headers=headers
        )

    async def raw_publish(
        self,
        stream_id: str,
        data: bytes,
        key: Optional[bytes] = None,
        headers: Optional[List[Tuple[str, bytes]]] = None,
    ) -> Awaitable[aiokafka.structs.ConsumerRecord]:
        logger.debug(f"Sending kafka msg: {stream_id}")
        producer = await self._get_producer()
        tracer = opentracing.tracer

        if not headers:
            headers = []
        else:
            # this is just to check the headers' shape
            try:
                for _, _ in headers:
                    pass
            except ValueError:
                # We want to be resilient to malformed headers
                logger.exception(f"Malformed headers: '{headers}'")

        if isinstance(tracer.scope_manager, ContextVarsScopeManager):
            # This only makes sense if the scope manager is asyncio-aware
            if tracer.active_span:
                carrier: Dict[str, str] = {}
                tracer.inject(
                    span_context=tracer.active_span,
                    format=opentracing.Format.TEXT_MAP,
                    carrier=carrier,
                )
                header_keys = [k for k, _ in headers]
                for k, v in carrier.items():
                    # Don't overwrite headers that are already present!
                    if k not in header_keys:
                        headers.append((k, v.encode()))

        if not self.producer_healthy():
            raise ProducerUnhealthyException(self._producer)  # type: ignore

        topic_id = self.topic_mng.get_topic_id(stream_id)
        start_time = time.time()
        with watch_publish(topic_id):
            fut = await producer.send(
                topic_id,
                value=data,
                key=key,
                headers=headers,
            )

        fut.add_done_callback(partial(published_callback, topic_id, start_time))  # type: ignore
        return fut

    async def flush(self) -> None:
        if self._producer is not None:
            await self._producer.flush()

    def get_schema_reg(self, model_or_def: BaseModel) -> Optional[SchemaRegistration]:
        try:
            key = model_or_def.__key__  # type: ignore
            return self._schemas[key]
        except (AttributeError, KeyError):
            return None

    def producer_healthy(self) -> bool:
        """
        It's possible for the producer to be unhealthy while we're still sending messages to it.
        """
        if self._producer is not None and self._producer._sender.sender_task is not None:
            return not self._producer._sender.sender_task.done()
        return True

    def consumer_factory(self, group_id: str) -> aiokafka.AIOKafkaConsumer:
        return aiokafka.AIOKafkaConsumer(
            bootstrap_servers=cast(List[str], self._kafka_servers),
            loop=asyncio.get_event_loop(),
            group_id=group_id,
            auto_offset_reset="earliest",
            api_version=self._kafka_api_version,
            enable_auto_commit=False,
            **{k: v for k, v in self.kafka_settings.items() if k in _aiokafka_consumer_settings},
        )

    def producer_factory(self) -> aiokafka.AIOKafkaProducer:
        return aiokafka.AIOKafkaProducer(
            bootstrap_servers=cast(List[str], self._kafka_servers),
            loop=asyncio.get_event_loop(),
            api_version=self._kafka_api_version,
            **{k: v for k, v in self.kafka_settings.items() if k in _aiokafka_producer_settings},
        )

    async def _get_producer(self) -> aiokafka.AIOKafkaProducer:
        if self._producer is None:
            self._producer = self.producer_factory()
            with watch_kafka("producer_start"):
                await self._producer.start()
        return self._producer

    async def initialize(self) -> None:
        if not self.is_configured:
            raise AppNotConfiguredException

        await self._call_event_handlers("initialize")

        for reg in self._schemas.values():
            # initialize topics for known streams
            for stream_id in reg.streams or []:
                topic_id = self.topic_mng.get_topic_id(stream_id)
                async with self.get_lock(stream_id):
                    if not await self.topic_mng.topic_exists(topic_id):
                        await self.topic_mng.create_topic(
                            topic_id,
                            retention_ms=reg.retention * 1000
                            if reg.retention is not None
                            else None,
                        )

        self._initialized = True

    async def finalize(self) -> None:
        await self._call_event_handlers("finalize")
        await self.stop()

        if self._producer is not None:
            with watch_kafka("producer_flush"):
                await self._producer.flush()
            with watch_kafka("producer_stop"):
                await self._producer.stop()

        if self._topic_mng is not None:
            await self._topic_mng.finalize()

        self._producer = None
        self._initialized = False
        self._topic_mng = None

    async def __aenter__(self) -> "Application":
        await self.initialize()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        logger.info("Stopping application...", exc_info=exc)
        await self.finalize()

    async def consume_for(self, num_messages: int, *, seconds: Optional[int] = None) -> int:
        consumed = 0
        self._subscription_consumers = []
        tasks = []
        for subscription in self._subscriptions:

            async def on_message(record: aiokafka.structs.ConsumerRecord) -> None:
                nonlocal consumed
                consumed += 1
                if consumed >= num_messages:
                    raise StopConsumer

            consumer = BatchConsumer(
                subscription=subscription,
                app=self,
                event_handlers={"message": [on_message]},
                auto_commit=self.auto_commit,
            )
            self._subscription_consumers.append(consumer)
            tasks.append(asyncio.create_task(consumer(), name=str(consumer)))

        done, pending = await asyncio.wait(
            tasks, timeout=seconds, return_when=asyncio.FIRST_EXCEPTION
        )
        await self.stop()

        # re-raise any errors so we can validate during tests
        for task in done:
            exc = task.exception()
            if exc is not None:
                raise exc

        for task in pending:
            task.cancel()

        return consumed

    def consume_forever(self) -> Awaitable:
        self._subscription_consumers = []
        self._subscription_consumers_tasks = []

        for subscription in self._subscriptions:
            consumer = BatchConsumer(
                subscription=subscription,
                app=self,
                auto_commit=self.auto_commit,
            )
            self._subscription_consumers.append(consumer)

        self._subscription_consumers_tasks = [
            asyncio.create_task(c()) for c in self._subscription_consumers
        ]
        return asyncio.wait(self._subscription_consumers_tasks, return_when=asyncio.FIRST_EXCEPTION)

    async def stop(self) -> None:
        async with self.get_lock("_"):
            # do not allow concurrent stop calls
            if len(self._subscription_consumers) == 0:
                return

            _, pending = await asyncio.wait(
                [c.stop() for c in self._subscription_consumers if c], timeout=5
            )
            for task in pending:
                # stop tasks that didn't finish
                task.cancel()

            for task in self._subscription_consumers_tasks:
                # make sure everything is done
                if not task.done():
                    task.cancel()

            for task in self._subscription_consumers_tasks:
                try:
                    await asyncio.wait([task])
                except asyncio.CancelledError:
                    ...
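
# End-to-end sketch (assumes a reachable broker at localhost:9092): registering
# a schema directly on the Application, publishing one message, and consuming it:
#
#     app = Application(kafka_servers=["localhost:9092"])
#
#     @app.schema("Content", version=1)
#     class ContentMessage(pydantic.BaseModel):
#         foo: str
#
#     @app.subscribe("content.*", group="content_group")
#     async def on_content(data: ContentMessage) -> None:
#         print(data.foo)
#
#     async def main() -> None:
#         async with app:
#             await app.publish("content.stream", ContentMessage(foo="bar"))
#             await app.consume_for(1, seconds=5)
#
#     asyncio.run(main())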

cli_parser = argparse.ArgumentParser(description="Run kafkaesk worker.")
cli_parser.add_argument("app", help="Application object")
cli_parser.add_argument("--kafka-servers", help="Kafka servers")
cli_parser.add_argument("--kafka-settings", help="Kafka settings")
cli_parser.add_argument("--topic-prefix", help="Topic prefix")
cli_parser.add_argument("--api-version", help="Kafka API Version")


def _sig_handler(app: Application) -> None:
    asyncio.create_task(app.stop())


async def run_app(app: Application) -> None:
    async with app:
        loop = asyncio.get_event_loop()
        fut = asyncio.create_task(app.consume_forever())
        for signame in {"SIGINT", "SIGTERM"}:
            loop.add_signal_handler(getattr(signal, signame), partial(_sig_handler, app))
        done, pending = await fut
        logger.debug("Exiting consumer")
        await app.stop()
        # re-raise any errors so we can validate during tests
        for task in done:
            exc = task.exception()
            if exc is not None:
                raise exc


def run(app: Optional[Application] = None) -> None:
    if app is None:
        # load the Application from the "module:attribute" CLI argument;
        # the CLI options only apply in this branch, where `opts` is defined
        opts = cli_parser.parse_args()
        module_str, attr = opts.app.split(":")
        module = resolve_dotted_name(module_str)
        app = getattr(module, attr)
        if callable(app):
            app = app()

        app = cast(Application, app)

        if opts.kafka_servers:
            app.configure(kafka_servers=opts.kafka_servers.split(","))
        if opts.kafka_settings:
            app.configure(kafka_settings=orjson.loads(opts.kafka_settings))
        if opts.topic_prefix:
            app.configure(topic_prefix=opts.topic_prefix)
        if opts.api_version:
            app.configure(api_version=opts.api_version)

    try:
        asyncio.run(run_app(app))
    except asyncio.CancelledError:  # pragma: no cover
        logger.debug("Closing because task was exited")
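
# Invocation sketch (module path and servers are hypothetical, and the
# console-script name depends on how `run()` is wired up in packaging):
#
#     kafkaesk my_module:app --kafka-servers=localhost:9092,localhost:9093
#
# `my_module:app` must resolve to an Application instance, or to a callable
# returning one, as handled by `run()` above.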