# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import os
import shutil
import tempfile

from s3transfer.bandwidth import (
    BandwidthLimitedStream,
    BandwidthLimiter,
    BandwidthRateTracker,
    ConsumptionScheduler,
    LeakyBucket,
    RequestExceededException,
    RequestToken,
    TimeUtils,
)
from s3transfer.futures import TransferCoordinator

from tests import mock, unittest
  28. class FixedIncrementalTickTimeUtils(TimeUtils):
  29. def __init__(self, seconds_per_tick=1.0):
  30. self._count = 0
  31. self._seconds_per_tick = seconds_per_tick
  32. def time(self):
  33. current_count = self._count
  34. self._count += self._seconds_per_tick
  35. return current_count
  36. class TestTimeUtils(unittest.TestCase):
  37. @mock.patch('time.time')
  38. def test_time(self, mock_time):
  39. mock_return_val = 1
  40. mock_time.return_value = mock_return_val
  41. time_utils = TimeUtils()
  42. self.assertEqual(time_utils.time(), mock_return_val)
  43. @mock.patch('time.sleep')
  44. def test_sleep(self, mock_sleep):
  45. time_utils = TimeUtils()
  46. time_utils.sleep(1)
  47. self.assertEqual(mock_sleep.call_args_list, [mock.call(1)])
  48. class BaseBandwidthLimitTest(unittest.TestCase):
  49. def setUp(self):
  50. self.leaky_bucket = mock.Mock(LeakyBucket)
  51. self.time_utils = mock.Mock(TimeUtils)
  52. self.tempdir = tempfile.mkdtemp()
  53. self.content = b'a' * 1024 * 1024
  54. self.filename = os.path.join(self.tempdir, 'myfile')
  55. with open(self.filename, 'wb') as f:
  56. f.write(self.content)
  57. self.coordinator = TransferCoordinator()
  58. def tearDown(self):
  59. shutil.rmtree(self.tempdir)
  60. def assert_consume_calls(self, amts):
  61. expected_consume_args = [mock.call(amt, mock.ANY) for amt in amts]
  62. self.assertEqual(
  63. self.leaky_bucket.consume.call_args_list, expected_consume_args
  64. )
  65. class TestBandwidthLimiter(BaseBandwidthLimitTest):
  66. def setUp(self):
  67. super().setUp()
  68. self.bandwidth_limiter = BandwidthLimiter(self.leaky_bucket)
  69. def test_get_bandwidth_limited_stream(self):
  70. with open(self.filename, 'rb') as f:
  71. stream = self.bandwidth_limiter.get_bandwith_limited_stream(
  72. f, self.coordinator
  73. )
  74. self.assertIsInstance(stream, BandwidthLimitedStream)
  75. self.assertEqual(stream.read(len(self.content)), self.content)
  76. self.assert_consume_calls(amts=[len(self.content)])
  77. def test_get_disabled_bandwidth_limited_stream(self):
  78. with open(self.filename, 'rb') as f:
  79. stream = self.bandwidth_limiter.get_bandwith_limited_stream(
  80. f, self.coordinator, enabled=False
  81. )
  82. self.assertIsInstance(stream, BandwidthLimitedStream)
  83. self.assertEqual(stream.read(len(self.content)), self.content)
  84. self.leaky_bucket.consume.assert_not_called()
  85. class TestBandwidthLimitedStream(BaseBandwidthLimitTest):
  86. def setUp(self):
  87. super().setUp()
  88. self.bytes_threshold = 1
  89. def tearDown(self):
  90. shutil.rmtree(self.tempdir)
  91. def get_bandwidth_limited_stream(self, f):
  92. return BandwidthLimitedStream(
  93. f,
  94. self.leaky_bucket,
  95. self.coordinator,
  96. self.time_utils,
  97. self.bytes_threshold,
  98. )
  99. def assert_sleep_calls(self, amts):
  100. expected_sleep_args_list = [mock.call(amt) for amt in amts]
  101. self.assertEqual(
  102. self.time_utils.sleep.call_args_list, expected_sleep_args_list
  103. )
  104. def get_unique_consume_request_tokens(self):
  105. return {
  106. call_args[0][1]
  107. for call_args in self.leaky_bucket.consume.call_args_list
  108. }
  109. def test_read(self):
  110. with open(self.filename, 'rb') as f:
  111. stream = self.get_bandwidth_limited_stream(f)
  112. data = stream.read(len(self.content))
  113. self.assertEqual(self.content, data)
  114. self.assert_consume_calls(amts=[len(self.content)])
  115. self.assert_sleep_calls(amts=[])
  116. def test_retries_on_request_exceeded(self):
  117. with open(self.filename, 'rb') as f:
  118. stream = self.get_bandwidth_limited_stream(f)
  119. retry_time = 1
  120. amt_requested = len(self.content)
  121. self.leaky_bucket.consume.side_effect = [
  122. RequestExceededException(amt_requested, retry_time),
  123. len(self.content),
  124. ]
  125. data = stream.read(len(self.content))
  126. self.assertEqual(self.content, data)
  127. self.assert_consume_calls(amts=[amt_requested, amt_requested])
  128. self.assert_sleep_calls(amts=[retry_time])
  129. def test_with_transfer_coordinator_exception(self):
  130. self.coordinator.set_exception(ValueError())
  131. with open(self.filename, 'rb') as f:
  132. stream = self.get_bandwidth_limited_stream(f)
  133. with self.assertRaises(ValueError):
  134. stream.read(len(self.content))
  135. def test_read_when_bandwidth_limiting_disabled(self):
  136. with open(self.filename, 'rb') as f:
  137. stream = self.get_bandwidth_limited_stream(f)
  138. stream.disable_bandwidth_limiting()
  139. data = stream.read(len(self.content))
  140. self.assertEqual(self.content, data)
  141. self.assertFalse(self.leaky_bucket.consume.called)
  142. def test_read_toggle_disable_enable_bandwidth_limiting(self):
  143. with open(self.filename, 'rb') as f:
  144. stream = self.get_bandwidth_limited_stream(f)
  145. stream.disable_bandwidth_limiting()
  146. data = stream.read(1)
  147. self.assertEqual(self.content[:1], data)
  148. self.assert_consume_calls(amts=[])
  149. stream.enable_bandwidth_limiting()
  150. data = stream.read(len(self.content) - 1)
  151. self.assertEqual(self.content[1:], data)
  152. self.assert_consume_calls(amts=[len(self.content) - 1])
  153. def test_seek(self):
  154. mock_fileobj = mock.Mock()
  155. stream = self.get_bandwidth_limited_stream(mock_fileobj)
  156. stream.seek(1)
  157. self.assertEqual(mock_fileobj.seek.call_args_list, [mock.call(1, 0)])
  158. def test_tell(self):
  159. mock_fileobj = mock.Mock()
  160. stream = self.get_bandwidth_limited_stream(mock_fileobj)
  161. stream.tell()
  162. self.assertEqual(mock_fileobj.tell.call_args_list, [mock.call()])
  163. def test_close(self):
  164. mock_fileobj = mock.Mock()
  165. stream = self.get_bandwidth_limited_stream(mock_fileobj)
  166. stream.close()
  167. self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()])
  168. def test_context_manager(self):
  169. mock_fileobj = mock.Mock()
  170. stream = self.get_bandwidth_limited_stream(mock_fileobj)
  171. with stream as stream_handle:
  172. self.assertIs(stream_handle, stream)
  173. self.assertEqual(mock_fileobj.close.call_args_list, [mock.call()])
  174. def test_reuses_request_token(self):
  175. with open(self.filename, 'rb') as f:
  176. stream = self.get_bandwidth_limited_stream(f)
  177. stream.read(1)
  178. stream.read(1)
  179. self.assertEqual(len(self.get_unique_consume_request_tokens()), 1)
  180. def test_request_tokens_unique_per_stream(self):
  181. with open(self.filename, 'rb') as f:
  182. stream = self.get_bandwidth_limited_stream(f)
  183. stream.read(1)
  184. with open(self.filename, 'rb') as f:
  185. stream = self.get_bandwidth_limited_stream(f)
  186. stream.read(1)
  187. self.assertEqual(len(self.get_unique_consume_request_tokens()), 2)
  188. def test_call_consume_after_reaching_threshold(self):
  189. self.bytes_threshold = 2
  190. with open(self.filename, 'rb') as f:
  191. stream = self.get_bandwidth_limited_stream(f)
  192. self.assertEqual(stream.read(1), self.content[:1])
  193. self.assert_consume_calls(amts=[])
  194. self.assertEqual(stream.read(1), self.content[1:2])
  195. self.assert_consume_calls(amts=[2])
  196. def test_resets_after_reaching_threshold(self):
  197. self.bytes_threshold = 2
  198. with open(self.filename, 'rb') as f:
  199. stream = self.get_bandwidth_limited_stream(f)
  200. self.assertEqual(stream.read(2), self.content[:2])
  201. self.assert_consume_calls(amts=[2])
  202. self.assertEqual(stream.read(1), self.content[2:3])
  203. self.assert_consume_calls(amts=[2])
  204. def test_pending_bytes_seen_on_close(self):
  205. self.bytes_threshold = 2
  206. with open(self.filename, 'rb') as f:
  207. stream = self.get_bandwidth_limited_stream(f)
  208. self.assertEqual(stream.read(1), self.content[:1])
  209. self.assert_consume_calls(amts=[])
  210. stream.close()
  211. self.assert_consume_calls(amts=[1])
  212. def test_no_bytes_remaining_on(self):
  213. self.bytes_threshold = 2
  214. with open(self.filename, 'rb') as f:
  215. stream = self.get_bandwidth_limited_stream(f)
  216. self.assertEqual(stream.read(2), self.content[:2])
  217. self.assert_consume_calls(amts=[2])
  218. stream.close()
  219. # There should have been no more consume() calls made
  220. # as all bytes have been accounted for in the previous
  221. # consume() call.
  222. self.assert_consume_calls(amts=[2])
  223. def test_disable_bandwidth_limiting_with_pending_bytes_seen_on_close(self):
  224. self.bytes_threshold = 2
  225. with open(self.filename, 'rb') as f:
  226. stream = self.get_bandwidth_limited_stream(f)
  227. self.assertEqual(stream.read(1), self.content[:1])
  228. self.assert_consume_calls(amts=[])
  229. stream.disable_bandwidth_limiting()
  230. stream.close()
  231. self.assert_consume_calls(amts=[])
  232. def test_signal_transferring(self):
  233. with open(self.filename, 'rb') as f:
  234. stream = self.get_bandwidth_limited_stream(f)
  235. stream.signal_not_transferring()
  236. data = stream.read(1)
  237. self.assertEqual(self.content[:1], data)
  238. self.assert_consume_calls(amts=[])
  239. stream.signal_transferring()
  240. data = stream.read(len(self.content) - 1)
  241. self.assertEqual(self.content[1:], data)
  242. self.assert_consume_calls(amts=[len(self.content) - 1])
  243. class TestLeakyBucket(unittest.TestCase):
  244. def setUp(self):
  245. self.max_rate = 1
  246. self.time_now = 1.0
  247. self.time_utils = mock.Mock(TimeUtils)
  248. self.time_utils.time.return_value = self.time_now
  249. self.scheduler = mock.Mock(ConsumptionScheduler)
  250. self.scheduler.is_scheduled.return_value = False
  251. self.rate_tracker = mock.Mock(BandwidthRateTracker)
  252. self.leaky_bucket = LeakyBucket(
  253. self.max_rate, self.time_utils, self.rate_tracker, self.scheduler
  254. )
  255. def set_projected_rate(self, rate):
  256. self.rate_tracker.get_projected_rate.return_value = rate
  257. def set_retry_time(self, retry_time):
  258. self.scheduler.schedule_consumption.return_value = retry_time
  259. def assert_recorded_consumed_amt(self, expected_amt):
  260. self.assertEqual(
  261. self.rate_tracker.record_consumption_rate.call_args,
  262. mock.call(expected_amt, self.time_utils.time.return_value),
  263. )
  264. def assert_was_scheduled(self, amt, token):
  265. self.assertEqual(
  266. self.scheduler.schedule_consumption.call_args,
  267. mock.call(amt, token, amt / (self.max_rate)),
  268. )
  269. def assert_nothing_scheduled(self):
  270. self.assertFalse(self.scheduler.schedule_consumption.called)
  271. def assert_processed_request_token(self, request_token):
  272. self.assertEqual(
  273. self.scheduler.process_scheduled_consumption.call_args,
  274. mock.call(request_token),
  275. )
  276. def test_consume_under_max_rate(self):
  277. amt = 1
  278. self.set_projected_rate(self.max_rate / 2)
  279. self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
  280. self.assert_recorded_consumed_amt(amt)
  281. self.assert_nothing_scheduled()
  282. def test_consume_at_max_rate(self):
  283. amt = 1
  284. self.set_projected_rate(self.max_rate)
  285. self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt)
  286. self.assert_recorded_consumed_amt(amt)
  287. self.assert_nothing_scheduled()
  288. def test_consume_over_max_rate(self):
  289. amt = 1
  290. retry_time = 2.0
  291. self.set_projected_rate(self.max_rate + 1)
  292. self.set_retry_time(retry_time)
  293. request_token = RequestToken()
  294. try:
  295. self.leaky_bucket.consume(amt, request_token)
  296. self.fail('A RequestExceededException should have been thrown')
  297. except RequestExceededException as e:
  298. self.assertEqual(e.requested_amt, amt)
  299. self.assertEqual(e.retry_time, retry_time)
  300. self.assert_was_scheduled(amt, request_token)
  301. def test_consume_with_scheduled_retry(self):
  302. amt = 1
  303. self.set_projected_rate(self.max_rate + 1)
  304. self.scheduler.is_scheduled.return_value = True
  305. request_token = RequestToken()
  306. self.assertEqual(self.leaky_bucket.consume(amt, request_token), amt)
  307. # Nothing new should have been scheduled but the request token
  308. # should have been processed.
  309. self.assert_nothing_scheduled()
  310. self.assert_processed_request_token(request_token)
class TestConsumptionScheduler(unittest.TestCase):
    """Tests for ConsumptionScheduler's accumulated wait-time bookkeeping."""

    def setUp(self):
        self.scheduler = ConsumptionScheduler()

    def test_schedule_consumption(self):
        token = RequestToken()
        consume_time = 5
        actual_wait_time = self.scheduler.schedule_consumption(
            1, token, consume_time
        )
        # The first scheduled request waits exactly its own consume time.
        self.assertEqual(consume_time, actual_wait_time)

    def test_schedule_consumption_for_multiple_requests(self):
        token = RequestToken()
        consume_time = 5
        actual_wait_time = self.scheduler.schedule_consumption(
            1, token, consume_time
        )
        self.assertEqual(consume_time, actual_wait_time)
        other_consume_time = 3
        other_token = RequestToken()
        next_wait_time = self.scheduler.schedule_consumption(
            1, other_token, other_consume_time
        )
        # This wait time should be the previous time plus its desired
        # wait time.
        self.assertEqual(next_wait_time, consume_time + other_consume_time)

    def test_is_scheduled(self):
        token = RequestToken()
        consume_time = 5
        self.scheduler.schedule_consumption(1, token, consume_time)
        self.assertTrue(self.scheduler.is_scheduled(token))

    def test_is_not_scheduled(self):
        # A token that was never scheduled reports as not scheduled.
        self.assertFalse(self.scheduler.is_scheduled(RequestToken()))

    def test_process_scheduled_consumption(self):
        token = RequestToken()
        consume_time = 5
        self.scheduler.schedule_consumption(1, token, consume_time)
        self.scheduler.process_scheduled_consumption(token)
        self.assertFalse(self.scheduler.is_scheduled(token))
        different_time = 7
        # The previous consume time should have no effect on the next wait
        # time as it has been completed.
        self.assertEqual(
            self.scheduler.schedule_consumption(1, token, different_time),
            different_time,
        )
  356. class TestBandwidthRateTracker(unittest.TestCase):
  357. def setUp(self):
  358. self.alpha = 0.8
  359. self.rate_tracker = BandwidthRateTracker(self.alpha)
  360. def test_current_rate_at_initilizations(self):
  361. self.assertEqual(self.rate_tracker.current_rate, 0.0)
  362. def test_current_rate_after_one_recorded_point(self):
  363. self.rate_tracker.record_consumption_rate(1, 1)
  364. # There is no last time point to do a diff against so return a
  365. # current rate of 0.0
  366. self.assertEqual(self.rate_tracker.current_rate, 0.0)
  367. def test_current_rate(self):
  368. self.rate_tracker.record_consumption_rate(1, 1)
  369. self.rate_tracker.record_consumption_rate(1, 2)
  370. self.rate_tracker.record_consumption_rate(1, 3)
  371. self.assertEqual(self.rate_tracker.current_rate, 0.96)
  372. def test_get_projected_rate_at_initilizations(self):
  373. self.assertEqual(self.rate_tracker.get_projected_rate(1, 1), 0.0)
  374. def test_get_projected_rate(self):
  375. self.rate_tracker.record_consumption_rate(1, 1)
  376. self.rate_tracker.record_consumption_rate(1, 2)
  377. projected_rate = self.rate_tracker.get_projected_rate(1, 3)
  378. self.assertEqual(projected_rate, 0.96)
  379. self.rate_tracker.record_consumption_rate(1, 3)
  380. self.assertEqual(self.rate_tracker.current_rate, projected_rate)
  381. def test_get_projected_rate_for_same_timestamp(self):
  382. self.rate_tracker.record_consumption_rate(1, 1)
  383. self.assertEqual(
  384. self.rate_tracker.get_projected_rate(1, 1), float('inf')
  385. )