test_consumer.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. from kafkaesk import Application
  2. from kafkaesk import Subscription
  3. from kafkaesk.consumer import build_handler
  4. from kafkaesk.consumer import BatchConsumer, Subscription
  5. from kafkaesk.exceptions import ConsumerUnhealthyException
  6. from kafkaesk.exceptions import StopConsumer
  7. from kafkaesk.exceptions import UnhandledMessage
  8. from tests.utils import record_factory
  9. from unittest.mock import AsyncMock
  10. from unittest.mock import MagicMock
  11. from unittest.mock import Mock
  12. from unittest.mock import patch
  13. import aiokafka.errors
  14. import asyncio
  15. import opentracing
  16. import pydantic
  17. import pytest
  18. import time
  19. import json
  20. pytestmark = pytest.mark.asyncio
  21. @pytest.fixture()
  22. def subscription_conf():
  23. subscription = Subscription(
  24. "foo",
  25. lambda record: 1,
  26. "group",
  27. topics=["foo"],
  28. timeout_seconds=1,
  29. )
  30. yield subscription
  31. @pytest.fixture()
  32. def subscription(subscription_conf):
  33. yield BatchConsumer(
  34. subscription=subscription_conf,
  35. app=Application(kafka_servers=["foobar"]),
  36. )
  37. def test_subscription_repr():
  38. sub = Subscription("stream_id", lambda x: None, "group")
  39. assert repr(sub) == "<Subscription stream: stream_id >"
  40. class TestMessageHandler:
  41. def factory(self, func):
  42. return build_handler(func, app=MagicMock(), consumer=None)
  43. async def test_message_handler(self):
  44. side_effect = None
  45. async def raw_func(data):
  46. nonlocal side_effect
  47. assert isinstance(data, dict)
  48. side_effect = True
  49. handler = self.factory(raw_func)
  50. await handler(record_factory(), None)
  51. assert side_effect is True
  52. async def test_message_handler_map_types(self):
  53. class Foo(pydantic.BaseModel):
  54. foo: str
  55. async def handle_func(ob: Foo, schema, record, app, span: opentracing.Span):
  56. assert ob.foo == "bar"
  57. assert schema == "Foo:1"
  58. assert record is not None
  59. assert app is not None
  60. assert span is not None
  61. handler = self.factory(handle_func)
  62. await handler(record_factory(), MagicMock())
  63. async def test_malformed_message(self):
  64. class Foo(pydantic.BaseModel):
  65. foo: str
  66. side_effect = None
  67. async def func(ob: Foo):
  68. nonlocal side_effect
  69. side_effect = True
  70. record = aiokafka.structs.ConsumerRecord(
  71. topic="topic",
  72. partition=0,
  73. offset=0,
  74. timestamp=time.time() * 1000,
  75. timestamp_type=1,
  76. key="key",
  77. value=json.dumps({"schema": "Foo:1", "data": "bad format"}).encode(),
  78. checksum="1",
  79. serialized_key_size=10,
  80. serialized_value_size=10,
  81. headers=[],
  82. )
  83. handler = self.factory(func)
  84. with pytest.raises(UnhandledMessage):
  85. await handler(record, None)
  86. assert side_effect is None
  87. class TestSubscriptionConsumer:
  88. async def test_healthy(self, subscription):
  89. subscription._consumer = MagicMock()
  90. subscription._running = True
  91. subscription._consumer._coordinator.coordinator_id = "coordinator_id"
  92. subscription._consumer._client.ready = AsyncMock(return_value=True)
  93. assert await subscription.healthy() is None
  94. subscription._consumer._client.ready.assert_called_with("coordinator_id")
  95. async def test_unhealthy(self, subscription):
  96. subscription._consumer = MagicMock()
  97. subscription._running = True
  98. subscription._consumer._client.ready = AsyncMock(return_value=False)
  99. with pytest.raises(ConsumerUnhealthyException):
  100. assert await subscription.healthy()
  101. subscription._consumer = MagicMock()
  102. subscription._running = False
  103. with pytest.raises(ConsumerUnhealthyException):
  104. assert await subscription.healthy()
  105. async def test_emit(self, subscription_conf):
  106. probe = AsyncMock()
  107. sub = BatchConsumer(
  108. subscription=subscription_conf,
  109. app=Application(kafka_servers=["foobar"]),
  110. event_handlers={"event": [probe]},
  111. )
  112. await sub.emit("event", "foo", "bar")
  113. probe.assert_called_with("foo", "bar")
  114. async def test_emit_raises_stop(self, subscription_conf):
  115. sub = BatchConsumer(
  116. subscription=subscription_conf,
  117. app=Application(kafka_servers=["foobar"]),
  118. event_handlers={"event": [AsyncMock(side_effect=StopConsumer)]},
  119. )
  120. with pytest.raises(StopConsumer):
  121. await sub.emit("event", "foo", "bar")
  122. async def test_emit_swallow_ex(self, subscription_conf):
  123. sub = BatchConsumer(
  124. subscription=subscription_conf,
  125. app=Application(kafka_servers=["foobar"]),
  126. event_handlers={"event": [AsyncMock(side_effect=Exception)]},
  127. )
  128. await sub.emit("event", "foo", "bar")
  129. async def test_retries_on_connection_failure(self, subscription):
  130. run_mock = AsyncMock()
  131. sleep = AsyncMock()
  132. run_mock.side_effect = [aiokafka.errors.KafkaConnectionError, StopConsumer]
  133. subscription._consumer = MagicMock()
  134. with patch.object(subscription, "initialize", AsyncMock()), patch.object(
  135. subscription, "finalize", AsyncMock()
  136. ), patch.object(subscription, "_consume", run_mock), patch(
  137. "kafkaesk.consumer.asyncio.sleep", sleep
  138. ):
  139. await subscription()
  140. sleep.assert_called_once()
  141. assert len(run_mock.mock_calls) == 2
  142. async def test_finalize_handles_exceptions(self, subscription):
  143. consumer = AsyncMock()
  144. consumer.stop.side_effect = Exception
  145. consumer.commit.side_effect = Exception
  146. subscription._consumer = consumer
  147. await subscription.finalize()
  148. consumer.stop.assert_called_once()
  149. async def test_run_exits_when_fut_closed_fut(self, subscription):
  150. sub = subscription
  151. consumer = AsyncMock()
  152. consumer.getmany.return_value = {"": [record_factory() for _ in range(10)]}
  153. sub._consumer = consumer
  154. sub._running = True
  155. async def _handle_message(record):
  156. await asyncio.sleep(0.03)
  157. with patch.object(sub, "_handler", _handle_message):
  158. task = asyncio.create_task(sub._consume())
  159. await asyncio.sleep(0.01)
  160. stop_task = asyncio.create_task(sub.stop())
  161. await asyncio.sleep(0.01)
  162. sub._close.set_result(None)
  163. await asyncio.wait([stop_task, task])
  164. async def test_auto_commit_can_be_disabled(self, subscription_conf):
  165. sub = BatchConsumer(
  166. subscription=subscription_conf,
  167. app=Application(kafka_servers=["foobar"]),
  168. auto_commit=False,
  169. )
  170. await sub._maybe_commit()
  171. assert sub._last_commit == 0