multipart_upload_job.py

import logging

from .utils import _threads, _chunk_by_size, MIN_S3_SIZE

logger = logging.getLogger(__name__)


class MultipartUploadJob:

    def __init__(self, bucket, result_filepath, data_input,
                 s3,
                 small_parts_threads=1,
                 add_part_number=True,
                 content_type='application/octet-stream'):
        # s3 cannot be a class var because the Pool cannot pickle it
        # threading support coming soon
        self.bucket = bucket
        self.part_number, self.parts_list = data_input
        self.content_type = content_type
        self.small_parts_threads = small_parts_threads

        if add_part_number:
            if '.' in result_filepath.split('/')[-1]:
                # If there is a file extension, put the part number before it
                path_parts = result_filepath.rsplit('.', 1)
                self.result_filepath = '{}-{}.{}'.format(path_parts[0],
                                                         self.part_number,
                                                         path_parts[1])
            else:
                self.result_filepath = '{}-{}'.format(result_filepath,
                                                      self.part_number)
        else:
            self.result_filepath = result_filepath

        if len(self.parts_list) == 1:
            # Perform a simple S3 copy since there is just a single file
            source_file = "{}/{}".format(self.bucket, self.parts_list[0][0])
            resp = s3.copy_object(Bucket=self.bucket,
                                  CopySource=source_file,
                                  Key=self.result_filepath)
            msg = "Copied single file to {}".format(self.result_filepath)
            if logger.getEffectiveLevel() == logging.DEBUG:
                logger.debug("{}, got response: {}".format(msg, resp))
            else:
                logger.info(msg)

        elif len(self.parts_list) > 1:
            self.upload_id = self._start_multipart_upload(s3)
            parts_mapping = self._assemble_parts(s3)
            self._complete_concatenation(s3, parts_mapping)

    def _start_multipart_upload(self, s3):
        resp = s3.create_multipart_upload(Bucket=self.bucket,
                                          Key=self.result_filepath,
                                          ContentType=self.content_type)
        msg = "Started multipart upload for {}".format(self.result_filepath)
        if logger.getEffectiveLevel() == logging.DEBUG:
            logger.debug("{}, got response: {}".format(msg, resp))
        else:
            logger.info(msg)
        return resp['UploadId']
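
    # (Descriptive note, not original to the module: the multipart lifecycle
    # used in this class is the standard boto3 one -- create_multipart_upload
    # returns an UploadId, parts are then attached with upload_part /
    # upload_part_copy, and the upload is finalized with
    # complete_multipart_upload or discarded with abort_multipart_upload.)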

    def _assemble_parts(self, s3):
        # TODO: Thread the loops in this function
        parts_mapping = []
        part_num = 0

        s3_parts = ["{}/{}".format(self.bucket, p[0])
                    for p in self.parts_list if p[1] > MIN_S3_SIZE]

        local_parts = [p for p in self.parts_list if p[1] <= MIN_S3_SIZE]

        # Assemble the parts large enough for a direct S3 copy
        for part_num, source_part in enumerate(s3_parts, 1):
            resp = s3.upload_part_copy(Bucket=self.bucket,
                                       Key=self.result_filepath,
                                       PartNumber=part_num,
                                       UploadId=self.upload_id,
                                       CopySource=source_part)
            msg = "Setup S3 part #{}, with path: {}".format(part_num,
                                                            source_part)
            logger.debug("{}, got response: {}".format(msg, resp))
            # Ceph doesn't return quoted ETags
            etag = (resp['CopyPartResult']['ETag']
                    .replace("'", "").replace("\"", ""))
            parts_mapping.append({'ETag': etag, 'PartNumber': part_num})

        # Assemble the parts too small for a direct S3 copy by downloading
        # them, combining them, and re-uploading them as the trailing parts
        # of the multipart upload (only the final part is exempt from the
        # 5 MB minimum). The small parts are concatenated into chunks of at
        # least the minimum size before uploading, so that not too much
        # data is kept in memory at once.
        def get_small_parts(data):
            part_num, part = data
            small_part_count = len(part[1])

            logger.debug("Start sub-part #{} from {} files"
                         .format(part_num, small_part_count))
            small_parts = []
            for p in part[1]:
                try:
                    small_parts.append(
                        s3.get_object(
                            Bucket=self.bucket,
                            Key=p[0]
                        )['Body'].read()
                    )
                except Exception as e:
                    logger.critical(
                        f"{e}: When getting {p[0]} from the bucket {self.bucket}")  # noqa: E501
                    raise

            if len(small_parts) > 0:
                last_part = b''.join(small_parts)
                small_parts = None  # cleanup
                resp = s3.upload_part(Bucket=self.bucket,
                                      Key=self.result_filepath,
                                      PartNumber=part_num,
                                      UploadId=self.upload_id,
                                      Body=last_part)
                msg = "Finish sub-part #{} from {} files"\
                    .format(part_num, small_part_count)
                logger.debug("{}, got response: {}".format(msg, resp))
                last_part = None

                # Handles both quoted and unquoted ETags
                etag = resp['ETag'].replace("'", "").replace("\"", "")
                return {'ETag': etag,
                        'PartNumber': part_num}
            return {}

        data_to_thread = []
        for idx, data in enumerate(_chunk_by_size(local_parts,
                                                  MIN_S3_SIZE * 2),
                                   start=1):
            data_to_thread.append([part_num + idx, data])

        parts_mapping.extend(
            _threads(self.small_parts_threads,
                     data_to_thread,
                     get_small_parts)
        )

        # Sort the part mapping by part number
        return sorted(parts_mapping, key=lambda i: i['PartNumber'])
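
    # (Descriptive note, not original to the module: the mapping returned
    # above is exactly what S3's complete_multipart_upload expects under the
    # 'Parts' key. An illustrative result for a two-part upload, with
    # hypothetical ETag values:
    #
    #     [{'ETag': '9b2cf535f27731c974343645a3985328', 'PartNumber': 1},
    #      {'ETag': '3858f62230ac3c915f300c664312c63f', 'PartNumber': 2}]
    # )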

    def _complete_concatenation(self, s3, parts_mapping):
        if len(parts_mapping) == 0:
            s3.abort_multipart_upload(Bucket=self.bucket,
                                      Key=self.result_filepath,
                                      UploadId=self.upload_id)
            warn_msg = ("Aborted concatenation for file {}, with upload"
                        " id #{} due to empty parts mapping")
            logger.error(warn_msg.format(self.result_filepath,
                                         self.upload_id))
        else:
            parts = {'Parts': parts_mapping}
            s3.complete_multipart_upload(Bucket=self.bucket,
                                         Key=self.result_filepath,
                                         UploadId=self.upload_id,
                                         MultipartUpload=parts)
            msg = ("Finished concatenation for file {},"
                   " with upload id #{}")
            logger.info(msg.format(self.result_filepath,
                                   self.upload_id))
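

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). The
# shape of ``data_input`` is inferred from ``__init__`` above: a
# ``(part_number, parts_list)`` tuple where ``parts_list`` is a list of
# ``(s3_key, size_in_bytes)`` tuples. The bucket name and keys below are
# hypothetical; ``s3`` is assumed to be a boto3 S3 client, since the class
# calls boto3 client methods such as copy_object and upload_part_copy.
#
#     import boto3
#
#     s3 = boto3.client('s3')
#     parts = [('logs/2024/part-0001.log', 8_500_000),
#              ('logs/2024/part-0002.log', 1_200_000)]
#     MultipartUploadJob('my-bucket', 'logs/2024/combined.log',
#                        (1, parts), s3)
# ---------------------------------------------------------------------------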