s3-mp-download_3.py 2.6 KB

import logging
import os
import time
import urlparse
from multiprocessing import Pool

import boto
from boto.s3.connection import OrdinaryCallingFormat

logger = logging.getLogger(__name__)


# Download an S3 object to a local file, splitting it into parallel ranged GETs
# when it is larger than 1 MB. gen_byte_ranges() and do_part_download() are
# defined elsewhere in the script.
def main(src, dest, num_processes=2, split=32, force=False, verbose=False,
         quiet=False, secure=True, max_tries=5):
    # Check that src is a valid S3 url
    split_rs = urlparse.urlsplit(src)
    if split_rs.scheme != "s3":
        raise ValueError("'%s' is not an S3 url" % src)

    # Check that dest does not exist
    if os.path.isdir(dest):
        filename = split_rs.path.split('/')[-1]
        dest = os.path.join(dest, filename)
    if os.path.exists(dest):
        if force:
            os.remove(dest)
        else:
            raise ValueError("Destination file '%s' exists, specify -f to"
                             " overwrite" % dest)

    # Split out the bucket and the key
    s3 = boto.connect_s3(calling_format=OrdinaryCallingFormat())
    s3.is_secure = secure
    logger.debug("split_rs: %s" % str(split_rs))
    bucket = s3.lookup(split_rs.netloc)
    if bucket is None:
        raise ValueError("'%s' is not a valid bucket" % split_rs.netloc)
    key = bucket.get_key(split_rs.path)
    if key is None:
        raise ValueError("'%s' does not exist." % split_rs.path)

    # Determine the total size and calculate byte ranges
    resp = s3.make_request("HEAD", bucket=bucket, key=key)
    if resp is None:
        raise ValueError("response is invalid.")
    size = int(resp.getheader("content-length"))
    logger.debug("Got headers: %s" % resp.getheaders())

    # Skip multipart if the file is less than 1 MB
    if size < 1024 * 1024:
        t1 = time.time()
        key.get_contents_to_filename(dest)
        t2 = time.time() - t1
        # Float MB so small downloads don't log as 0.00M under integer division
        size_mb = size / 1024. / 1024
        logger.info("Finished single-part download of %0.2fM in %0.2fs (%0.2fMBps)" %
                    (size_mb, t2, size_mb / t2))
    else:
        # Touch the file so every worker can open it and seek to its own offset
        fd = os.open(dest, os.O_CREAT)
        os.close(fd)

        # Round the size in MB up to the next multiple of `split` to get the part count
        size_mb = size / 1024 / 1024
        num_parts = (size_mb + (-size_mb % split)) // split

        def arg_iterator(num_parts):
            for min_byte, max_byte in gen_byte_ranges(size, num_parts):
                yield (bucket.name, key.name, dest, min_byte, max_byte,
                       split, secure, max_tries, 0)

        s = size / 1024 / 1024.
        try:
            t1 = time.time()
            pool = Pool(processes=num_processes)
            # map_async().get() with a long timeout keeps KeyboardInterrupt working
            pool.map_async(do_part_download, arg_iterator(num_parts)).get(9999999)
            t2 = time.time() - t1
            logger.info("Finished downloading %0.2fM in %0.2fs (%0.2fMBps)" %
                        (s, t2, s / t2))
        except KeyboardInterrupt:
            logger.warning("User terminated")
        except Exception as err:
            logger.error(err)
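
The helpers gen_byte_ranges() and do_part_download() are not part of this file. From the way main() calls them, gen_byte_ranges(size, num_parts) yields (min_byte, max_byte) pairs covering the object, and do_part_download(args) fetches one range with an HTTP Range header and writes it at the matching offset in dest. The sketch below is an illustrative assumption of those two helpers, not the script's actual code; it reuses the imports from the file above plus math.

import math

def gen_byte_ranges(size, num_parts):
    # Yield inclusive (min_byte, max_byte) pairs that partition `size` bytes
    # into `num_parts` roughly equal ranges.
    part_size = int(math.ceil(size / float(num_parts)))
    for i in range(num_parts):
        yield (part_size * i, min(part_size * (i + 1) - 1, size - 1))

def do_part_download(args):
    # Download one byte range with an HTTP Range header and write it into the
    # shared destination file at its own offset. Each pool worker opens its
    # own S3 connection because connections do not survive fork/pickling.
    (bucket_name, key_name, fname, min_byte, max_byte,
     split, secure, max_tries, current_tries) = args
    conn = boto.connect_s3(calling_format=OrdinaryCallingFormat())
    conn.is_secure = secure
    resp = conn.make_request("GET", bucket=bucket_name, key=key_name,
                             headers={"Range": "bytes=%d-%d" % (min_byte, max_byte)})
    fd = os.open(fname, os.O_WRONLY)
    try:
        os.lseek(fd, min_byte, os.SEEK_SET)
        while True:
            data = resp.read(1024 * 1024)
            if not data:
                break
            os.write(fd, data)
    except Exception as err:
        # Simple retry: re-attempt this part until max_tries is exhausted.
        if current_tries >= max_tries:
            logger.error(err)
        else:
            time.sleep(3)
            do_part_download((bucket_name, key_name, fname, min_byte, max_byte,
                              split, secure, max_tries, current_tries + 1))
    finally:
        os.close(fd)

With helpers along these lines in place, the entry point is invoked as something like main("s3://some-bucket/path/to/object", "/tmp/"), and anything over 1 MB is fetched by num_processes workers writing into the pre-created destination file.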