file_name_formatter.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. #!/usr/bin/env python3
  2. # coding=utf-8
  3. import datetime
  4. import logging
  5. import platform
  6. import re
  7. import subprocess
  8. from pathlib import Path
  9. from typing import Optional
  10. from praw.models import Comment, Submission
  11. from bdfr.exceptions import BulkDownloaderException
  12. from bdfr.resource import Resource
  13. logger = logging.getLogger(__name__)
  14. class FileNameFormatter:
  15. key_terms = (
  16. 'date',
  17. 'flair',
  18. 'postid',
  19. 'redditor',
  20. 'subreddit',
  21. 'title',
  22. 'upvotes',
  23. )
  24. def __init__(self, file_format_string: str, directory_format_string: str, time_format_string: str):
  25. if not self.validate_string(file_format_string):
  26. raise BulkDownloaderException(f'"{file_format_string}" is not a valid format string')
  27. self.file_format_string = file_format_string
  28. self.directory_format_string: list[str] = directory_format_string.split('/')
  29. self.time_format_string = time_format_string
  30. def _format_name(self, submission: (Comment, Submission), format_string: str) -> str:
  31. if isinstance(submission, Submission):
  32. attributes = self._generate_name_dict_from_submission(submission)
  33. elif isinstance(submission, Comment):
  34. attributes = self._generate_name_dict_from_comment(submission)
  35. else:
  36. raise BulkDownloaderException(f'Cannot name object {type(submission).__name__}')
  37. result = format_string
  38. for key in attributes.keys():
  39. if re.search(fr'(?i).*{{{key}}}.*', result):
  40. key_value = str(attributes.get(key, 'unknown'))
  41. key_value = FileNameFormatter._convert_unicode_escapes(key_value)
  42. key_value = key_value.replace('\\', '\\\\')
  43. result = re.sub(fr'(?i){{{key}}}', key_value, result)
  44. result = result.replace('/', '')
  45. if platform.system() == 'Windows':
  46. result = FileNameFormatter._format_for_windows(result)
  47. return result
  48. @staticmethod
  49. def _convert_unicode_escapes(in_string: str) -> str:
  50. pattern = re.compile(r'(\\u\d{4})')
  51. matches = re.search(pattern, in_string)
  52. if matches:
  53. for match in matches.groups():
  54. converted_match = bytes(match, 'utf-8').decode('unicode-escape')
  55. in_string = in_string.replace(match, converted_match)
  56. return in_string
  57. def _generate_name_dict_from_submission(self, submission: Submission) -> dict:
  58. submission_attributes = {
  59. 'title': submission.title,
  60. 'subreddit': submission.subreddit.display_name,
  61. 'redditor': submission.author.name if submission.author else 'DELETED',
  62. 'postid': submission.id,
  63. 'upvotes': submission.score,
  64. 'flair': submission.link_flair_text,
  65. 'date': self._convert_timestamp(submission.created_utc),
  66. }
  67. return submission_attributes
  68. def _convert_timestamp(self, timestamp: float) -> str:
  69. input_time = datetime.datetime.fromtimestamp(timestamp)
  70. if self.time_format_string.upper().strip() == 'ISO':
  71. return input_time.isoformat()
  72. else:
  73. return input_time.strftime(self.time_format_string)
  74. def _generate_name_dict_from_comment(self, comment: Comment) -> dict:
  75. comment_attributes = {
  76. 'title': comment.submission.title,
  77. 'subreddit': comment.subreddit.display_name,
  78. 'redditor': comment.author.name if comment.author else 'DELETED',
  79. 'postid': comment.id,
  80. 'upvotes': comment.score,
  81. 'flair': '',
  82. 'date': self._convert_timestamp(comment.created_utc),
  83. }
  84. return comment_attributes
  85. def format_path(
  86. self,
  87. resource: Resource,
  88. destination_directory: Path,
  89. index: Optional[int] = None,
  90. ) -> Path:
  91. subfolder = Path(
  92. destination_directory,
  93. *[self._format_name(resource.source_submission, part) for part in self.directory_format_string],
  94. )
  95. index = f'_{str(index)}' if index else ''
  96. if not resource.extension:
  97. raise BulkDownloaderException(f'Resource from {resource.url} has no extension')
  98. file_name = str(self._format_name(resource.source_submission, self.file_format_string))
  99. file_name = re.sub(r'\n', ' ', file_name)
  100. if not re.match(r'.*\.$', file_name) and not re.match(r'^\..*', resource.extension):
  101. ending = index + '.' + resource.extension
  102. else:
  103. ending = index + resource.extension
  104. try:
  105. file_path = self.limit_file_name_length(file_name, ending, subfolder)
  106. except TypeError:
  107. raise BulkDownloaderException(f'Could not determine path name: {subfolder}, {index}, {resource.extension}')
  108. return file_path
  109. @staticmethod
  110. def limit_file_name_length(filename: str, ending: str, root: Path) -> Path:
  111. root = root.resolve().expanduser()
  112. possible_id = re.search(r'((?:_\w{6})?$)', filename)
  113. if possible_id:
  114. ending = possible_id.group(1) + ending
  115. filename = filename[:possible_id.start()]
  116. max_path = FileNameFormatter.find_max_path_length()
  117. max_file_part_length_chars = 255 - len(ending)
  118. max_file_part_length_bytes = 255 - len(ending.encode('utf-8'))
  119. max_path_length = max_path - len(ending) - len(str(root)) - 1
  120. out = Path(root, filename + ending)
  121. while any([len(filename) > max_file_part_length_chars,
  122. len(filename.encode('utf-8')) > max_file_part_length_bytes,
  123. len(str(out)) > max_path_length,
  124. ]):
  125. filename = filename[:-1]
  126. out = Path(root, filename + ending)
  127. return out
  128. @staticmethod
  129. def find_max_path_length() -> int:
  130. try:
  131. return int(subprocess.check_output(['getconf', 'PATH_MAX', '/']))
  132. except (ValueError, subprocess.CalledProcessError, OSError):
  133. if platform.system() == 'Windows':
  134. return 260
  135. else:
  136. return 4096
  137. def format_resource_paths(
  138. self,
  139. resources: list[Resource],
  140. destination_directory: Path,
  141. ) -> list[tuple[Path, Resource]]:
  142. out = []
  143. if len(resources) == 1:
  144. try:
  145. out.append((self.format_path(resources[0], destination_directory, None), resources[0]))
  146. except BulkDownloaderException as e:
  147. logger.error(f'Could not generate file path for resource {resources[0].url}: {e}')
  148. logger.exception('Could not generate file path')
  149. else:
  150. for i, res in enumerate(resources, start=1):
  151. logger.log(9, f'Formatting filename with index {i}')
  152. try:
  153. out.append((self.format_path(res, destination_directory, i), res))
  154. except BulkDownloaderException as e:
  155. logger.error(f'Could not generate file path for resource {res.url}: {e}')
  156. logger.exception('Could not generate file path')
  157. return out
  158. @staticmethod
  159. def validate_string(test_string: str) -> bool:
  160. if not test_string:
  161. return False
  162. result = any([f'{{{key}}}' in test_string.lower() for key in FileNameFormatter.key_terms])
  163. if result:
  164. if 'POSTID' not in test_string:
  165. logger.warning('Some files might not be downloaded due to name conflicts as filenames are'
  166. ' not guaranteed to be be unique without {POSTID}')
  167. return True
  168. else:
  169. return False
  170. @staticmethod
  171. def _format_for_windows(input_string: str) -> str:
  172. invalid_characters = r'<>:"\/|?*'
  173. for char in invalid_characters:
  174. input_string = input_string.replace(char, '')
  175. input_string = FileNameFormatter._strip_emojis(input_string)
  176. return input_string
  177. @staticmethod
  178. def _strip_emojis(input_string: str) -> str:
  179. result = input_string.encode('ascii', errors='ignore').decode('utf-8')
  180. return result