test_cloud_tasks.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. import pytest
  2. from google.api_core.exceptions import GoogleAPICallError, NotFound, RetryError
  3. from accession.cloud_tasks import (
  4. AwsCredentials,
  5. AwsS3Object,
  6. CloudTasksUploadClient,
  7. QueueInfo,
  8. UploadPayload,
  9. )
  10. from accession.file import GSFile
  11. @pytest.fixture
  12. def aws_credentials():
  13. return AwsCredentials(
  14. aws_access_key_id="foo", aws_secret_access_key="bar", aws_session_token="baz"
  15. )
  16. @pytest.fixture
  17. def aws_s3_object():
  18. return AwsS3Object(bucket="foo", key="bar")
  19. @pytest.fixture
  20. def upload_payload(mocker, aws_credentials, aws_s3_object):
  21. blob = mocker.Mock()
  22. blob.name = "name"
  23. blob.bucket.name = "bucket"
  24. mocker.patch("accession.file.GSFile.blob", mocker.PropertyMock(return_value=blob))
  25. gs_file = GSFile(key="foo", name="gs://bucket/name")
  26. return UploadPayload(
  27. aws_credentials=aws_credentials, aws_s3_object=aws_s3_object, gcs_blob=gs_file
  28. )
  29. @pytest.fixture
  30. def cloud_tasks_upload_client(mocker):
  31. mocker.patch(
  32. "accession.cloud_tasks.CloudTasksUploadClient.client",
  33. new_callable=mocker.PropertyMock(),
  34. )
  35. mocker.patch(
  36. "accession.cloud_tasks.CloudTasksUploadClient.logger",
  37. new_callable=mocker.PropertyMock(),
  38. )
  39. client = CloudTasksUploadClient(
  40. QueueInfo(region="us-west1", name="queue"), no_log_file=True
  41. )
  42. return client
  43. def test_queue_info_from_env(mocker):
  44. mocker.patch.dict(
  45. "os.environ",
  46. {
  47. "ACCESSION_CLOUD_TASKS_QUEUE_NAME": "foo",
  48. "ACCESSION_CLOUD_TASKS_QUEUE_REGION": "bar",
  49. },
  50. )
  51. result = QueueInfo.from_env()
  52. assert result == QueueInfo(name="foo", region="bar")
  53. def test_queue_info_from_env_env_vars_not_set_returns_none(mocker):
  54. mocker.patch.dict("os.environ", {"ACCESSION_CLOUD_TASKS_QUEUE_NAME": "foo"})
  55. result = QueueInfo.from_env()
  56. assert result is None
  57. def test_aws_credentials_get_dict(aws_credentials):
  58. result = aws_credentials.get_dict()
  59. assert result == {
  60. "aws_access_key_id": "foo",
  61. "aws_secret_access_key": "bar",
  62. "aws_session_token": "baz",
  63. }
  64. def test_aws_s3_object_get_dict(aws_s3_object):
  65. result = aws_s3_object.get_dict()
  66. assert result == {"bucket": "foo", "key": "bar"}
  67. def test_upload_payload_get_dict(upload_payload):
  68. result = upload_payload.get_dict()
  69. assert result["aws_s3_object"]
  70. assert result["aws_credentials"]
  71. assert result["gcs_blob"]["bucket"] == "bucket"
  72. assert result["gcs_blob"]["name"] == "name"
  73. def test_upload_payload_get_bytes(upload_payload):
  74. result = upload_payload.get_bytes()
  75. assert result.startswith(b'{"aws_credentials"')
  76. def test_upload_payload_get_task_id(upload_payload):
  77. result = upload_payload.get_task_id()
  78. assert result == "127ba03ad0ee8f4fcfe64a9172507e66"
  79. def test_cloud_tasks_upload_client_project_id(mocker, cloud_tasks_upload_client):
  80. mocker.patch("google.auth.default", return_value=("foo", "project-id"))
  81. result = cloud_tasks_upload_client.project_id
  82. assert result == "project-id"
  83. def test_cloud_tasks_upload_client_project_id_google_auth_returns_none_raises(
  84. mocker, cloud_tasks_upload_client
  85. ):
  86. mocker.patch("google.auth.default", return_value=("foo", None))
  87. with pytest.raises(ValueError):
  88. _ = cloud_tasks_upload_client.project_id
  89. def test_cloud_tasks_upload_client_get_queue_path(mocker, cloud_tasks_upload_client):
  90. mocker.patch("google.auth.default", return_value=("foo", "project-id"))
  91. cloud_tasks_upload_client.get_queue_path()
  92. assert cloud_tasks_upload_client.client.queue_path.called_once_with(
  93. "project-id", "us-west1", "queue"
  94. )
  95. def test_cloud_tasks_upload_client_validate_queue_info(
  96. mocker, cloud_tasks_upload_client
  97. ):
  98. mocker.patch.object(cloud_tasks_upload_client, "get_queue_path")
  99. cloud_tasks_upload_client.validate_queue_info()
  100. cloud_tasks_upload_client.get_queue_path.assert_called_once()
  101. cloud_tasks_upload_client.client.get_queue.assert_called_once()
  102. def test_cloud_tasks_upload_client_validate_queue_info_raises(
  103. mocker, cloud_tasks_upload_client
  104. ):
  105. mocker.patch.object(cloud_tasks_upload_client, "get_queue_path")
  106. cloud_tasks_upload_client.client.get_queue.side_effect = NotFound(message="failed")
  107. with pytest.raises(ValueError):
  108. cloud_tasks_upload_client.validate_queue_info()
  109. def test_cloud_tasks_upload_client_get_task_name(
  110. mocker, cloud_tasks_upload_client, upload_payload
  111. ):
  112. mocker.patch.object(cloud_tasks_upload_client, "get_queue_path", return_value="foo")
  113. mocker.patch.object(upload_payload, "get_task_id", return_value="123")
  114. result = cloud_tasks_upload_client._get_task_name(upload_payload)
  115. assert result == "foo/tasks/123"
  116. def test_cloud_tasks_upload_client_upload(
  117. mocker, cloud_tasks_upload_client, upload_payload
  118. ):
  119. mocker.patch.object(cloud_tasks_upload_client, "_submit_task")
  120. cloud_tasks_upload_client.upload(upload_payload)
  121. assert cloud_tasks_upload_client._submit_task.called_once_with(
  122. "/upload", upload_payload
  123. )
  124. def test_cloud_tasks_upload_client_submit_task(
  125. mocker, cloud_tasks_upload_client, upload_payload
  126. ):
  127. mocker.patch.object(
  128. cloud_tasks_upload_client, "get_queue_path", return_value="queue-path"
  129. )
  130. cloud_tasks_upload_client._submit_task("/endpoint", upload_payload)
  131. assert cloud_tasks_upload_client.client.create_task.called_once_with(
  132. "queue-path", upload_payload
  133. )
  134. assert cloud_tasks_upload_client.logger.info.call_args[0][2] == "/endpoint"
  135. def test_cloud_tasks_upload_client_submit_task_value_error(
  136. mocker, cloud_tasks_upload_client, upload_payload
  137. ):
  138. mocker.patch.object(
  139. cloud_tasks_upload_client, "get_queue_path", return_value="queue-path"
  140. )
  141. mocker.patch.object(
  142. cloud_tasks_upload_client.client, "create_task", side_effect=ValueError("error")
  143. )
  144. with pytest.raises(ValueError):
  145. try:
  146. cloud_tasks_upload_client._submit_task("/endpoint", upload_payload)
  147. finally:
  148. assert cloud_tasks_upload_client.logger.exception.called_once()
  149. def test_cloud_tasks_upload_client_submit_task_google_api_call_error(
  150. mocker, cloud_tasks_upload_client, upload_payload
  151. ):
  152. mocker.patch.object(
  153. cloud_tasks_upload_client, "get_queue_path", return_value="queue-path"
  154. )
  155. mocker.patch.object(
  156. cloud_tasks_upload_client.client,
  157. "create_task",
  158. side_effect=GoogleAPICallError(message="foo"),
  159. )
  160. with pytest.raises(GoogleAPICallError):
  161. try:
  162. cloud_tasks_upload_client._submit_task("/endpoint", upload_payload)
  163. finally:
  164. assert cloud_tasks_upload_client.logger.exception.call_args[0][1] == "foo"
  165. def test_cloud_tasks_upload_client_submit_task_retry_error(
  166. mocker, cloud_tasks_upload_client, upload_payload
  167. ):
  168. mocker.patch.object(
  169. cloud_tasks_upload_client, "get_queue_path", return_value="queue-path"
  170. )
  171. mocker.patch.object(
  172. cloud_tasks_upload_client.client,
  173. "create_task",
  174. side_effect=RetryError(message="foo", cause=Exception("bar")),
  175. )
  176. with pytest.raises(RetryError):
  177. try:
  178. cloud_tasks_upload_client._submit_task("/endpoint", upload_payload)
  179. finally:
  180. assert cloud_tasks_upload_client.logger.exception.call_args[0][1] == "foo"