test_app.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. from asyncio.futures import Future
  2. from kafkaesk.app import Application
  3. from kafkaesk.app import published_callback
  4. from kafkaesk.app import run
  5. from kafkaesk.app import run_app
  6. from kafkaesk.app import SchemaRegistration
  7. from jaeger_client import Config, Tracer
  8. from opentracing.scope_managers.contextvars import ContextVarsScopeManager
  9. from tests.utils import record_factory
  10. from unittest.mock import ANY
  11. from unittest.mock import AsyncMock
  12. from unittest.mock import MagicMock
  13. from unittest.mock import Mock
  14. from unittest.mock import patch
  15. import asyncio
  16. import json
  17. import kafkaesk
  18. import kafkaesk.exceptions
  19. import opentracing
  20. import pydantic
  21. import pytest
  22. import time
# Apply pytest-asyncio's ``asyncio`` marker to every test in this module so the
# ``async def`` tests are run on an event loop.
# NOTE(review): this module-level mark also lands on the synchronous tests
# (e.g. test_publish_callback, test_configure, TestRun.test_run); some
# pytest-asyncio versions warn about that — confirm against the pinned version.
pytestmark = pytest.mark.asyncio
  24. class TestApplication:
  25. async def test_app_events(self):
  26. app = Application()
  27. async def on_finalize():
  28. pass
  29. app.on("finalize", on_finalize)
  30. assert len(app._event_handlers["finalize"]) == 1
  31. async def test_app_finalize_event(self):
  32. app = Application()
  33. class CallTracker:
  34. def __init__(self):
  35. self.called = False
  36. async def on_finalize(self):
  37. self.called = True
  38. tracker = CallTracker()
  39. app.on("finalize", tracker.on_finalize)
  40. await app.finalize()
  41. assert tracker.called is True
  42. def test_publish_callback(self, metrics):
  43. fut = Future()
  44. fut.set_result(record_factory())
  45. published_callback("topic", time.time() - 1, fut)
  46. metrics["PUBLISHED_MESSAGES"].labels.assert_called_with(
  47. stream_id="topic", partition=0, error="none"
  48. )
  49. metrics["PUBLISHED_MESSAGES"].labels().inc()
  50. metrics["PRODUCER_TOPIC_OFFSET"].labels.assert_called_with(stream_id="topic", partition=0)
  51. metrics["PRODUCER_TOPIC_OFFSET"].labels().set.assert_called_with(0)
  52. metrics["PUBLISHED_MESSAGES_TIME"].labels.assert_called_with(stream_id="topic")
  53. assert metrics["PUBLISHED_MESSAGES_TIME"].labels().observe.mock_calls[0].args[
  54. 0
  55. ] == pytest.approx(1, 0.1)
  56. def test_publish_callback_exc(self, metrics):
  57. fut = Future()
  58. fut.set_exception(Exception())
  59. published_callback("topic", time.time(), fut)
  60. metrics["PUBLISHED_MESSAGES"].labels.assert_called_with(
  61. stream_id="topic", partition=-1, error="Exception"
  62. )
  63. metrics["PUBLISHED_MESSAGES"].labels().inc()
  64. def test_mount_router(self):
  65. app = Application()
  66. router = kafkaesk.Router()
  67. @router.schema("Foo", streams=["foo.bar"])
  68. class Foo(pydantic.BaseModel):
  69. bar: str
  70. @router.subscribe("foo.bar", group="test_group")
  71. async def consume(data: Foo, schema, record):
  72. ...
  73. app.mount(router)
  74. assert app.subscriptions == router.subscriptions
  75. assert app.schemas == router.schemas
  76. assert app.event_handlers == router.event_handlers
  77. async def test_consumer_health_check(self):
  78. app = kafkaesk.Application()
  79. subscription_consumer = AsyncMock()
  80. app._subscription_consumers.append(subscription_consumer)
  81. subscription_consumer.consumer._client.ready.return_value = True
  82. await app.health_check()
  83. async def test_consumer_health_check_raises_exception(self):
  84. app = kafkaesk.Application()
  85. subscription = kafkaesk.Subscription(
  86. "test_consumer", lambda record: 1, "group", topics=["foo"]
  87. )
  88. subscription_consumer = kafkaesk.BatchConsumer(
  89. subscription=subscription,
  90. app=app,
  91. )
  92. app._subscription_consumers.append(subscription_consumer)
  93. subscription_consumer._consumer = AsyncMock()
  94. subscription_consumer._consumer._client.ready.return_value = False
  95. with pytest.raises(kafkaesk.exceptions.ConsumerUnhealthyException):
  96. await app.health_check()
  97. async def test_consumer_health_check_producer_healthy(self):
  98. app = kafkaesk.Application()
  99. app._producer = MagicMock()
  100. app._producer._sender.sender_task.done.return_value = False
  101. await app.health_check()
  102. async def test_consumer_health_check_producer_unhealthy(self):
  103. app = kafkaesk.Application()
  104. app._producer = MagicMock()
  105. app._producer._sender.sender_task.done.return_value = True
  106. with pytest.raises(kafkaesk.exceptions.ProducerUnhealthyException):
  107. await app.health_check()
  108. async def test_configure_kafka_producer(self):
  109. app = kafkaesk.Application(
  110. kafka_settings={
  111. "metadata_max_age_ms": 100,
  112. "max_batch_size": 100,
  113. # invalid for producer so should not be applied here
  114. "max_partition_fetch_bytes": 100,
  115. }
  116. )
  117. # verify it is created correctly
  118. app.producer_factory()
  119. # now, validate the wiring
  120. with patch("kafkaesk.app.aiokafka.AIOKafkaProducer") as mock:
  121. app.producer_factory()
  122. mock.assert_called_with(
  123. bootstrap_servers=None,
  124. loop=ANY,
  125. api_version="auto",
  126. metadata_max_age_ms=100,
  127. max_batch_size=100,
  128. )
  129. async def test_configure_kafka_consumer(self):
  130. app = kafkaesk.Application(
  131. kafka_settings={
  132. "max_partition_fetch_bytes": 100,
  133. "fetch_max_wait_ms": 100,
  134. "metadata_max_age_ms": 100,
  135. # invalid for consumer so should not be applied here
  136. "max_batch_size": 100,
  137. }
  138. )
  139. # verify it is created correctly
  140. app.consumer_factory(group_id="foobar")
  141. # now, validate the wiring
  142. with patch("kafkaesk.app.aiokafka.AIOKafkaConsumer") as mock:
  143. app.consumer_factory(group_id="foobar")
  144. mock.assert_called_with(
  145. bootstrap_servers=None,
  146. loop=ANY,
  147. group_id="foobar",
  148. api_version="auto",
  149. auto_offset_reset="earliest",
  150. enable_auto_commit=False,
  151. max_partition_fetch_bytes=100,
  152. fetch_max_wait_ms=100,
  153. metadata_max_age_ms=100,
  154. )
  155. def test_configure(self):
  156. app = kafkaesk.Application()
  157. app.configure(
  158. kafka_servers=["kafka_servers"],
  159. topic_prefix="topic_prefix",
  160. kafka_settings={"kafka_settings": "kafka_settings"},
  161. api_version="api_version",
  162. replication_factor="replication_factor",
  163. )
  164. assert app._kafka_servers == ["kafka_servers"]
  165. assert app._topic_prefix == "topic_prefix"
  166. assert app._kafka_settings == {"kafka_settings": "kafka_settings"}
  167. assert app._kafka_api_version == "api_version"
  168. assert app._replication_factor == "replication_factor"
  169. # now make sure none values do not overwrite
  170. app.configure(
  171. kafka_servers=None,
  172. topic_prefix=None,
  173. kafka_settings=None,
  174. api_version=None,
  175. replication_factor=None,
  176. )
  177. assert app._kafka_servers == ["kafka_servers"]
  178. assert app._topic_prefix == "topic_prefix"
  179. assert app._kafka_settings == {"kafka_settings": "kafka_settings"}
  180. assert app._kafka_api_version == "api_version"
  181. assert app._replication_factor == "replication_factor"
  182. async def test_initialize_with_unconfigured_app_raises_exception(self):
  183. app = kafkaesk.Application()
  184. with pytest.raises(kafkaesk.exceptions.AppNotConfiguredException):
  185. await app.initialize()
  186. async def test_publish_propagates_headers(self):
  187. app = kafkaesk.Application(kafka_servers=["foo"])
  188. class Foo(pydantic.BaseModel):
  189. bar: str
  190. producer = AsyncMock()
  191. producer.send.return_value = fut = asyncio.Future()
  192. fut.set_result("ok")
  193. app._get_producer = AsyncMock(return_value=producer)
  194. app._topic_mng = MagicMock()
  195. app._topic_mng.get_topic_id.return_value = "foobar"
  196. app._topic_mng.topic_exists = AsyncMock(return_value=True)
  197. future = await app.publish("foobar", Foo(bar="foo"), headers=[("foo", b"bar")])
  198. _ = await future
  199. producer.send.assert_called_with(
  200. "foobar",
  201. value=b'{"schema":"Foo:1","data":{"bar":"foo"}}',
  202. key=None,
  203. headers=[("foo", b"bar")],
  204. )
  205. async def test_publish_configured_retention_policy(self):
  206. app = kafkaesk.Application(kafka_servers=["foo"])
  207. @app.schema(retention=100)
  208. class Foo(pydantic.BaseModel):
  209. bar: str
  210. producer = AsyncMock()
  211. producer.send.return_value = fut = asyncio.Future()
  212. fut.set_result("ok")
  213. app._get_producer = AsyncMock(return_value=producer)
  214. app._topic_mng = MagicMock()
  215. app._topic_mng.get_topic_id.return_value = "foobar"
  216. app._topic_mng.topic_exists = AsyncMock(return_value=False)
  217. app._topic_mng.create_topic = AsyncMock()
  218. future = await app.publish("foobar", Foo(bar="foo"), headers=[("foo", b"bar")])
  219. await future
  220. app._topic_mng.create_topic.assert_called_with(
  221. "foobar", replication_factor=None, retention_ms=100 * 1000
  222. )
  223. async def test_publish_injects_tracing(self):
  224. app = kafkaesk.Application(kafka_servers=["foo"])
  225. producer = AsyncMock()
  226. producer.send.return_value = fut = asyncio.Future()
  227. fut.set_result("ok")
  228. app._get_producer = AsyncMock(return_value=producer)
  229. config = Config(
  230. config={"sampler": {"type": "const", "param": 1}, "logging": True, "propagation": "b3"},
  231. service_name="test_service",
  232. scope_manager=ContextVarsScopeManager(),
  233. )
  234. # this call also sets opentracing.tracer
  235. tracer = config.initialize_tracer()
  236. span = tracer.start_span(operation_name="dummy")
  237. tracer.scope_manager.activate(span, True)
  238. future = await app.raw_publish("foobar", b"foobar")
  239. await future
  240. headers = producer.mock_calls[0].kwargs["headers"]
  241. assert str(span).startswith(headers[0][1].decode())
  242. class TestSchemaRegistration:
  243. def test_schema_registration_repr(self):
  244. reg = SchemaRegistration(id="id", version=1, model=None)
  245. assert repr(reg) == "<SchemaRegistration id, version: 1 >"
  246. test_app = Application()
  247. def app_callable():
  248. return test_app
  249. class TestRun:
  250. def test_run(self):
  251. rapp = AsyncMock()
  252. with patch("kafkaesk.app.run_app", rapp), patch("kafkaesk.app.cli_parser") as cli_parser:
  253. args = Mock()
  254. args.app = "tests.unit.test_app:test_app"
  255. args.kafka_servers = "foo,bar"
  256. args.kafka_settings = json.dumps({"foo": "bar"})
  257. args.topic_prefix = "prefix"
  258. args.api_version = "api_version"
  259. cli_parser.parse_args.return_value = args
  260. run()
  261. rapp.assert_called_once()
  262. assert test_app._kafka_servers == ["foo", "bar"]
  263. assert test_app._kafka_settings == {"foo": "bar"}
  264. assert test_app._topic_prefix == "prefix"
  265. assert test_app._kafka_api_version == "api_version"
  266. def test_run_callable(self):
  267. rapp = AsyncMock()
  268. with patch("kafkaesk.app.run_app", rapp), patch("kafkaesk.app.cli_parser") as cli_parser:
  269. args = Mock()
  270. args.app = "tests.unit.test_app:app_callable"
  271. args.kafka_settings = None
  272. cli_parser.parse_args.return_value = args
  273. run()
  274. rapp.assert_called_once()
  275. async def test_run_app(self):
  276. app_mock = AsyncMock()
  277. app_mock.consume_forever.return_value = (set(), set())
  278. loop = MagicMock()
  279. with patch("kafkaesk.app.asyncio.get_event_loop", return_value=loop):
  280. await run_app(app_mock)
  281. app_mock.consume_forever.assert_called_once()
  282. assert len(loop.add_signal_handler.mock_calls) == 2