resource.py

#!/usr/bin/env python3
# coding=utf-8

import hashlib
import logging
import re
import time
import urllib.parse
from typing import Callable, Optional

import _hashlib
import requests
from praw.models import Submission

from bdfr.exceptions import BulkDownloaderException

logger = logging.getLogger(__name__)


class Resource:
    def __init__(self, source_submission: Submission, url: str, download_function: Callable, extension: str = None):
        self.source_submission = source_submission
        self.content: Optional[bytes] = None
        self.url = url
        self.hash: Optional[_hashlib.HASH] = None
        self.extension = extension
        self.download_function = download_function
        if not self.extension:
            self.extension = self._determine_extension()

    @staticmethod
    def retry_download(url: str) -> Callable:
        # Defer the request: the returned callable takes the global download
        # parameters and performs the HTTP download when invoked
        return lambda global_params: Resource.http_download(url, global_params)

    def download(self, download_parameters: Optional[dict] = None):
        if download_parameters is None:
            download_parameters = {}
        if not self.content:
            try:
                content = self.download_function(download_parameters)
            except requests.exceptions.ConnectionError as e:
                raise BulkDownloaderException(f'Could not download resource: {e}')
            except BulkDownloaderException:
                raise
            if content:
                self.content = content
        if not self.hash and self.content:
            self.create_hash()

    def create_hash(self):
        self.hash = hashlib.md5(self.content)

    def _determine_extension(self) -> Optional[str]:
        # Capture a 3-5 character extension (including the dot) from the end of the URL path
        extension_pattern = re.compile(r'.*(\..{3,5})$')
        stripped_url = urllib.parse.urlsplit(self.url).path
        match = re.search(extension_pattern, stripped_url)
        if match:
            return match.group(1)

    @staticmethod
    def http_download(url: str, download_parameters: dict) -> Optional[bytes]:
        headers = download_parameters.get('headers')
        current_wait_time = 60
        if 'max_wait_time' in download_parameters:
            max_wait_time = download_parameters['max_wait_time']
        else:
            max_wait_time = 300
        while True:
            try:
                response = requests.get(url, headers=headers)
                if re.match(r'^2\d{2}', str(response.status_code)) and response.content:
                    return response.content
                elif response.status_code in (408, 429):
                    # Request timeout / rate limiting: treat as transient and retry with backoff
                    raise requests.exceptions.ConnectionError(f'Response code {response.status_code}')
                else:
                    raise BulkDownloaderException(
                        f'Unrecoverable error requesting resource: HTTP Code {response.status_code}')
            except (requests.exceptions.ConnectionError, requests.exceptions.ChunkedEncodingError) as e:
                logger.warning(f'Error occurred downloading from {url}, waiting {current_wait_time} seconds: {e}')
                time.sleep(current_wait_time)
                if current_wait_time < max_wait_time:
                    current_wait_time += 60
                else:
                    logger.error(f'Max wait time exceeded for resource at url {url}')
                    raise
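
# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module). It
# assumes a PRAW Submission object is already available as `submission` and
# uses a hypothetical direct image URL; retry_download wraps http_download so
# the request is only made once download() is called.
#
#     url = 'https://i.redd.it/example.png'
#     resource = Resource(submission, url, Resource.retry_download(url))
#     resource.download({'max_wait_time': 180})
#     resource.extension         # '.png', parsed from the URL path
#     resource.hash.hexdigest()  # MD5 of the downloaded bytes
# ---------------------------------------------------------------------------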