s3-mp-copy.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. #!/usr/bin/env python
  2. import argparse
  3. from cStringIO import StringIO
  4. import logging
  5. from math import ceil
  6. from multiprocessing import Pool
  7. import sys
  8. import time
  9. import urlparse
  10. import boto
  11. from boto.s3.connection import OrdinaryCallingFormat
  12. parser = argparse.ArgumentParser(description="Copy large files within S3",
  13. prog="s3-mp-copy")
  14. parser.add_argument("src", help="The S3 source object")
  15. parser.add_argument("dest", help="The S3 destination object")
  16. parser.add_argument("-np", "--num-processes", help="Number of processors to use",
  17. type=int, default=2)
  18. parser.add_argument("-f", "--force", help="Overwrite an existing S3 key",
  19. action="store_true")
  20. parser.add_argument("-s", "--split", help="Split size, in Mb", type=int, default=50)
  21. parser.add_argument("-rrs", "--reduced-redundancy", help="Use reduced redundancy storage. Default is standard.",
  22. default=False, action="store_true")
  23. parser.add_argument("-v", "--verbose", help="Be more verbose", default=False, action="store_true")
  24. logger = logging.getLogger("s3-mp-copy")
  25. def do_part_copy(args):
  26. """
  27. Copy a part of a MultiPartUpload
  28. Copy a single chunk between S3 objects. Since we can't pickle
  29. S3Connection or MultiPartUpload objects, we have to reconnect and lookup
  30. the MPU object with each part upload.
  31. :type args: tuple of (string, string, string, int, int, int, int)
  32. :param args: The actual arguments of this method. Due to lameness of
  33. multiprocessing, we have to extract these outside of the
  34. function definition.
  35. The arguments are: S3 src bucket name, S3 key name, S3 dest
  36. bucket_name, MultiPartUpload id, the part number,
  37. part start position, part stop position
  38. """
  39. # Multiprocessing args lameness
  40. src_bucket_name, src_key_name, dest_bucket_name, mpu_id, part_num, start_pos, end_pos = args
  41. logger.debug("do_part_copy got args: %s" % (args,))
  42. # Connect to S3, get the MultiPartUpload
  43. s3 = boto.connect_s3(calling_format=OrdinaryCallingFormat())
  44. dest_bucket = s3.lookup(dest_bucket_name)
  45. mpu = None
  46. for mp in dest_bucket.list_multipart_uploads():
  47. if mp.id == mpu_id:
  48. mpu = mp
  49. break
  50. if mpu is None:
  51. raise Exception("Could not find MultiPartUpload %s" % mpu_id)
  52. # make sure we have a valid key
  53. src_bucket = s3.lookup( src_bucket_name )
  54. src_key = src_bucket.get_key( src_key_name )
  55. # Do the copy
  56. t1 = time.time()
  57. mpu.copy_part_from_key(src_bucket_name, src_key_name, part_num, start_pos, end_pos)
  58. # Print some timings
  59. t2 = time.time() - t1
  60. s = (end_pos - start_pos)/1024./1024.
  61. logger.info("Copied part %s (%0.2fM) in %0.2fs at %0.2fMbps" % (part_num, s, t2, s/t2))
  62. def validate_url( url ):
  63. split = urlparse.urlsplit( url )
  64. if split.scheme != "s3":
  65. raise ValueError("'%s' is not an S3 url" % url)
  66. return split.netloc, split.path[1:]
  67. def main(src, dest, num_processes=2, split=50, force=False, reduced_redundancy=False, verbose=False):
  68. dest_bucket_name, dest_key_name = validate_url( dest )
  69. src_bucket_name, src_key_name = validate_url( src )
  70. s3 = boto.connect_s3(calling_format=OrdinaryCallingFormat())
  71. dest_bucket = s3.lookup( dest_bucket_name )
  72. dest_key = dest_bucket.get_key( dest_key_name )
  73. # See if we're overwriting an existing key
  74. if dest_key is not None:
  75. if not force:
  76. raise ValueError("'%s' already exists. Specify -f to overwrite it" % dest)
  77. # Determine the total size and calculate byte ranges
  78. src_bucket = s3.lookup( src_bucket_name )
  79. src_key = src_bucket.get_key( src_key_name )
  80. size = src_key.size
  81. # If file is less than 5G, copy it directly
  82. if size < 5*1024*1024*1024:
  83. logging.info("Source object is %0.2fM copying it directly" % ( size/1024./1024. ))
  84. t1 = time.time()
  85. src_key.copy( dest_bucket_name, dest_key_name, reduced_redundancy=reduced_redundancy )
  86. t2 = time.time() - t1
  87. s = size/1024./1024.
  88. logger.info("Finished copying %0.2fM in %0.2fs (%0.2fMbps)" % (s, t2, s/t2))
  89. return
  90. part_size = max(5*1024*1024, 1024*1024*split)
  91. num_parts = int(ceil(size / float(part_size)))
  92. logging.info("Source object is %0.2fM splitting into %d parts of size %0.2fM" % (size/1024./1024., num_parts, part_size/1024./1024.) )
  93. # Create the multi-part upload object
  94. mpu = dest_bucket.initiate_multipart_upload( dest_key_name, reduced_redundancy=reduced_redundancy)
  95. logger.info("Initialized copy: %s" % mpu.id)
  96. # Generate arguments for invocations of do_part_copy
  97. def gen_args(num_parts):
  98. cur_pos = 0
  99. for i in range(num_parts):
  100. part_start = cur_pos
  101. cur_pos = cur_pos + part_size
  102. part_end = min(cur_pos - 1, size - 1)
  103. part_num = i + 1
  104. yield (src_bucket_name, src_key_name, dest_bucket_name, mpu.id, part_num, part_start, part_end)
  105. # Do the thing
  106. try:
  107. # Create a pool of workers
  108. pool = Pool(processes=num_processes)
  109. t1 = time.time()
  110. pool.map_async(do_part_copy, gen_args(num_parts)).get(9999999)
  111. # Print out some timings
  112. t2 = time.time() - t1
  113. s = size/1024./1024.
  114. # Finalize
  115. mpu.complete_upload()
  116. logger.info("Finished copying %0.2fM in %0.2fs (%0.2fMbps)" % (s, t2, s/t2))
  117. except KeyboardInterrupt:
  118. logger.warn("Received KeyboardInterrupt, canceling copy")
  119. pool.terminate()
  120. mpu.cancel_upload()
  121. except Exception, err:
  122. logger.error("Encountered an error, canceling copy")
  123. logger.error(err)
  124. mpu.cancel_upload()
  125. if __name__ == "__main__":
  126. logging.basicConfig(level=logging.INFO)
  127. args = parser.parse_args()
  128. arg_dict = vars(args)
  129. if arg_dict['verbose'] == True:
  130. logger.setLevel(logging.DEBUG)
  131. logger.debug("CLI args: %s" % args)
  132. main(**arg_dict)