test_connector.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. #!/usr/bin/env python3
  2. # coding=utf-8
  3. from datetime import datetime, timedelta
  4. from pathlib import Path
  5. from typing import Iterator
  6. from unittest.mock import MagicMock
  7. import praw
  8. import praw.models
  9. import pytest
  10. from bdfr.configuration import Configuration
  11. from bdfr.connector import RedditConnector, RedditTypes
  12. from bdfr.download_filter import DownloadFilter
  13. from bdfr.exceptions import BulkDownloaderException
  14. from bdfr.file_name_formatter import FileNameFormatter
  15. from bdfr.site_authenticator import SiteAuthenticator
  16. @pytest.fixture()
  17. def args() -> Configuration:
  18. args = Configuration()
  19. args.time_format = 'ISO'
  20. return args
  21. @pytest.fixture()
  22. def downloader_mock(args: Configuration):
  23. downloader_mock = MagicMock()
  24. downloader_mock.args = args
  25. downloader_mock.sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name
  26. downloader_mock.create_filtered_listing_generator = lambda x: RedditConnector.create_filtered_listing_generator(
  27. downloader_mock, x)
  28. downloader_mock.split_args_input = RedditConnector.split_args_input
  29. downloader_mock.master_hash_list = {}
  30. return downloader_mock
  31. def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]) -> list:
  32. results = [sub for res in results for sub in res]
  33. assert all([isinstance(res, praw.models.Submission) for res in results])
  34. assert not any([isinstance(m, MagicMock) for m in results])
  35. if result_limit is not None:
  36. assert len(results) == result_limit
  37. return results
  38. def assert_all_results_are_submissions_or_comments(result_limit: int, results: list[Iterator]) -> list:
  39. results = [sub for res in results for sub in res]
  40. assert all([isinstance(res, praw.models.Submission) or isinstance(res, praw.models.Comment) for res in results])
  41. assert not any([isinstance(m, MagicMock) for m in results])
  42. if result_limit is not None:
  43. assert len(results) == result_limit
  44. return results
  45. def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
  46. downloader_mock.args.directory = tmp_path / 'test'
  47. downloader_mock.config_directories.user_config_dir = tmp_path
  48. RedditConnector.determine_directories(downloader_mock)
  49. assert Path(tmp_path / 'test').exists()
  50. @pytest.mark.parametrize(('skip_extensions', 'skip_domains'), (
  51. ([], []),
  52. (['.test'], ['test.com'],),
  53. ))
  54. def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock):
  55. downloader_mock.args.skip = skip_extensions
  56. downloader_mock.args.skip_domain = skip_domains
  57. result = RedditConnector.create_download_filter(downloader_mock)
  58. assert isinstance(result, DownloadFilter)
  59. assert result.excluded_domains == skip_domains
  60. assert result.excluded_extensions == skip_extensions
  61. @pytest.mark.parametrize(('test_time', 'expected'), (
  62. ('all', 'all'),
  63. ('hour', 'hour'),
  64. ('day', 'day'),
  65. ('week', 'week'),
  66. ('random', 'all'),
  67. ('', 'all'),
  68. ))
  69. def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock):
  70. downloader_mock.args.time = test_time
  71. result = RedditConnector.create_time_filter(downloader_mock)
  72. assert isinstance(result, RedditTypes.TimeType)
  73. assert result.name.lower() == expected
  74. @pytest.mark.parametrize(('test_sort', 'expected'), (
  75. ('', 'hot'),
  76. ('hot', 'hot'),
  77. ('controversial', 'controversial'),
  78. ('new', 'new'),
  79. ))
  80. def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock):
  81. downloader_mock.args.sort = test_sort
  82. result = RedditConnector.create_sort_filter(downloader_mock)
  83. assert isinstance(result, RedditTypes.SortType)
  84. assert result.name.lower() == expected
  85. @pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
  86. ('{POSTID}', '{SUBREDDIT}'),
  87. ('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'),
  88. ('{POSTID}', 'test'),
  89. ('{POSTID}', ''),
  90. ('{POSTID}', '{SUBREDDIT}/{REDDITOR}'),
  91. ))
  92. def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
  93. downloader_mock.args.file_scheme = test_file_scheme
  94. downloader_mock.args.folder_scheme = test_folder_scheme
  95. result = RedditConnector.create_file_name_formatter(downloader_mock)
  96. assert isinstance(result, FileNameFormatter)
  97. assert result.file_format_string == test_file_scheme
  98. assert result.directory_format_string == test_folder_scheme.split('/')
  99. @pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), (
  100. ('', ''),
  101. ('', '{SUBREDDIT}'),
  102. ('test', '{SUBREDDIT}'),
  103. ))
  104. def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
  105. downloader_mock.args.file_scheme = test_file_scheme
  106. downloader_mock.args.folder_scheme = test_folder_scheme
  107. with pytest.raises(BulkDownloaderException):
  108. RedditConnector.create_file_name_formatter(downloader_mock)
  109. def test_create_authenticator(downloader_mock: MagicMock):
  110. result = RedditConnector.create_authenticator(downloader_mock)
  111. assert isinstance(result, SiteAuthenticator)
  112. @pytest.mark.online
  113. @pytest.mark.reddit
  114. @pytest.mark.parametrize('test_submission_ids', (
  115. ('lvpf4l',),
  116. ('lvpf4l', 'lvqnsn'),
  117. ('lvpf4l', 'lvqnsn', 'lvl9kd'),
  118. ))
  119. def test_get_submissions_from_link(
  120. test_submission_ids: list[str],
  121. reddit_instance: praw.Reddit,
  122. downloader_mock: MagicMock):
  123. downloader_mock.args.link = test_submission_ids
  124. downloader_mock.reddit_instance = reddit_instance
  125. results = RedditConnector.get_submissions_from_link(downloader_mock)
  126. assert all([isinstance(sub, praw.models.Submission) for res in results for sub in res])
  127. assert len(results[0]) == len(test_submission_ids)
  128. @pytest.mark.online
  129. @pytest.mark.reddit
  130. @pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), (
  131. (('Futurology',), 10, 'hot', 'all', 10),
  132. (('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30),
  133. (('Futurology',), 20, 'hot', 'all', 20),
  134. (('Futurology', 'Python'), 10, 'hot', 'all', 20),
  135. (('Futurology',), 100, 'hot', 'all', 100),
  136. (('Futurology',), 0, 'hot', 'all', 0),
  137. (('Futurology',), 10, 'top', 'all', 10),
  138. (('Futurology',), 10, 'top', 'week', 10),
  139. (('Futurology',), 10, 'hot', 'week', 10),
  140. ))
  141. def test_get_subreddit_normal(
  142. test_subreddits: list[str],
  143. limit: int,
  144. sort_type: str,
  145. time_filter: str,
  146. max_expected_len: int,
  147. downloader_mock: MagicMock,
  148. reddit_instance: praw.Reddit,
  149. ):
  150. downloader_mock.args.limit = limit
  151. downloader_mock.args.sort = sort_type
  152. downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock)
  153. downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock)
  154. downloader_mock.determine_sort_function.return_value = RedditConnector.determine_sort_function(downloader_mock)
  155. downloader_mock.args.subreddit = test_subreddits
  156. downloader_mock.reddit_instance = reddit_instance
  157. results = RedditConnector.get_subreddits(downloader_mock)
  158. test_subreddits = downloader_mock.split_args_input(test_subreddits)
  159. results = [sub for res1 in results for sub in res1]
  160. assert all([isinstance(res1, praw.models.Submission) for res1 in results])
  161. assert all([res.subreddit.display_name in test_subreddits for res in results])
  162. assert len(results) <= max_expected_len
  163. assert not any([isinstance(m, MagicMock) for m in results])
  164. @pytest.mark.online
  165. @pytest.mark.reddit
  166. @pytest.mark.parametrize(('test_time', 'test_delta'), (
  167. ('hour', timedelta(hours=1)),
  168. ('day', timedelta(days=1)),
  169. ('week', timedelta(days=7)),
  170. ('month', timedelta(days=31)),
  171. ('year', timedelta(days=365)),
  172. ))
  173. def test_get_subreddit_time_verification(
  174. test_time: str,
  175. test_delta: timedelta,
  176. downloader_mock: MagicMock,
  177. reddit_instance: praw.Reddit,
  178. ):
  179. downloader_mock.args.limit = 10
  180. downloader_mock.args.sort = 'top'
  181. downloader_mock.args.time = test_time
  182. downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock)
  183. downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock)
  184. downloader_mock.determine_sort_function.return_value = RedditConnector.determine_sort_function(downloader_mock)
  185. downloader_mock.args.subreddit = ['all']
  186. downloader_mock.reddit_instance = reddit_instance
  187. results = RedditConnector.get_subreddits(downloader_mock)
  188. results = [sub for res1 in results for sub in res1]
  189. assert all([isinstance(res1, praw.models.Submission) for res1 in results])
  190. nowtime = datetime.now()
  191. for r in results:
  192. result_time = datetime.fromtimestamp(r.created_utc)
  193. time_diff = nowtime - result_time
  194. assert time_diff < test_delta
  195. @pytest.mark.online
  196. @pytest.mark.reddit
  197. @pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), (
  198. (('Python',), 'scraper', 10, 'all', 10),
  199. (('Python',), '', 10, 'all', 0),
  200. (('Python',), 'djsdsgewef', 10, 'all', 0),
  201. (('Python',), 'scraper', 10, 'year', 10),
  202. ))
  203. def test_get_subreddit_search(
  204. test_subreddits: list[str],
  205. search_term: str,
  206. time_filter: str,
  207. limit: int,
  208. max_expected_len: int,
  209. downloader_mock: MagicMock,
  210. reddit_instance: praw.Reddit,
  211. ):
  212. downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
  213. downloader_mock.args.limit = limit
  214. downloader_mock.args.search = search_term
  215. downloader_mock.args.subreddit = test_subreddits
  216. downloader_mock.reddit_instance = reddit_instance
  217. downloader_mock.sort_filter = RedditTypes.SortType.HOT
  218. downloader_mock.args.time = time_filter
  219. downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock)
  220. results = RedditConnector.get_subreddits(downloader_mock)
  221. results = [sub for res in results for sub in res]
  222. assert all([isinstance(res, praw.models.Submission) for res in results])
  223. assert all([res.subreddit.display_name in test_subreddits for res in results])
  224. assert len(results) <= max_expected_len
  225. if max_expected_len != 0:
  226. assert len(results) > 0
  227. assert not any([isinstance(m, MagicMock) for m in results])
  228. @pytest.mark.online
  229. @pytest.mark.reddit
  230. @pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), (
  231. ('helen_darten', ('cuteanimalpics',), 10),
  232. ('korfor', ('chess',), 100),
  233. ))
  234. # Good sources at https://www.reddit.com/r/multihub/
  235. def test_get_multireddits_public(
  236. test_user: str,
  237. test_multireddits: list[str],
  238. limit: int,
  239. reddit_instance: praw.Reddit,
  240. downloader_mock: MagicMock,
  241. ):
  242. downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
  243. downloader_mock.sort_filter = RedditTypes.SortType.HOT
  244. downloader_mock.args.limit = limit
  245. downloader_mock.args.multireddit = test_multireddits
  246. downloader_mock.args.user = [test_user]
  247. downloader_mock.reddit_instance = reddit_instance
  248. downloader_mock.create_filtered_listing_generator.return_value = \
  249. RedditConnector.create_filtered_listing_generator(
  250. downloader_mock,
  251. reddit_instance.multireddit(test_user, test_multireddits[0]),
  252. )
  253. results = RedditConnector.get_multireddits(downloader_mock)
  254. results = [sub for res in results for sub in res]
  255. assert all([isinstance(res, praw.models.Submission) for res in results])
  256. assert len(results) == limit
  257. assert not any([isinstance(m, MagicMock) for m in results])
  258. @pytest.mark.online
  259. @pytest.mark.reddit
  260. @pytest.mark.parametrize(('test_user', 'limit'), (
  261. ('danigirl3694', 10),
  262. ('danigirl3694', 50),
  263. ('CapitanHam', None),
  264. ))
  265. def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
  266. downloader_mock.args.limit = limit
  267. downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
  268. downloader_mock.sort_filter = RedditTypes.SortType.HOT
  269. downloader_mock.args.submitted = True
  270. downloader_mock.args.user = [test_user]
  271. downloader_mock.authenticated = False
  272. downloader_mock.reddit_instance = reddit_instance
  273. downloader_mock.create_filtered_listing_generator.return_value = \
  274. RedditConnector.create_filtered_listing_generator(
  275. downloader_mock,
  276. reddit_instance.redditor(test_user).submissions,
  277. )
  278. results = RedditConnector.get_user_data(downloader_mock)
  279. results = assert_all_results_are_submissions(limit, results)
  280. assert all([res.author.name == test_user for res in results])
  281. assert not any([isinstance(m, MagicMock) for m in results])
  282. @pytest.mark.online
  283. @pytest.mark.reddit
  284. @pytest.mark.authenticated
  285. @pytest.mark.parametrize('test_flag', (
  286. 'upvoted',
  287. 'saved',
  288. ))
  289. def test_get_user_authenticated_lists(
  290. test_flag: str,
  291. downloader_mock: MagicMock,
  292. authenticated_reddit_instance: praw.Reddit,
  293. ):
  294. downloader_mock.args.__dict__[test_flag] = True
  295. downloader_mock.reddit_instance = authenticated_reddit_instance
  296. downloader_mock.args.limit = 10
  297. downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
  298. downloader_mock.sort_filter = RedditTypes.SortType.HOT
  299. downloader_mock.args.user = [RedditConnector.resolve_user_name(downloader_mock, 'me')]
  300. results = RedditConnector.get_user_data(downloader_mock)
  301. assert_all_results_are_submissions_or_comments(10, results)
  302. @pytest.mark.online
  303. @pytest.mark.reddit
  304. @pytest.mark.authenticated
  305. def test_get_subscribed_subreddits(downloader_mock: MagicMock, authenticated_reddit_instance: praw.Reddit):
  306. downloader_mock.reddit_instance = authenticated_reddit_instance
  307. downloader_mock.args.limit = 10
  308. downloader_mock.args.authenticate = True
  309. downloader_mock.args.subscribed = True
  310. downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
  311. downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
  312. downloader_mock.sort_filter = RedditTypes.SortType.HOT
  313. results = RedditConnector.get_subreddits(downloader_mock)
  314. assert all([isinstance(s, praw.models.ListingGenerator) for s in results])
  315. assert len(results) > 0
  316. @pytest.mark.parametrize(('test_name', 'expected'), (
  317. ('Mindustry', 'Mindustry'),
  318. ('Futurology', 'Futurology'),
  319. ('r/Mindustry', 'Mindustry'),
  320. ('TrollXChromosomes', 'TrollXChromosomes'),
  321. ('r/TrollXChromosomes', 'TrollXChromosomes'),
  322. ('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
  323. ('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
  324. ('https://www.reddit.com/r/Futurology/', 'Futurology'),
  325. ('https://www.reddit.com/r/Futurology', 'Futurology'),
  326. ))
  327. def test_sanitise_subreddit_name(test_name: str, expected: str):
  328. result = RedditConnector.sanitise_subreddit_name(test_name)
  329. assert result == expected
  330. @pytest.mark.parametrize(('test_subreddit_entries', 'expected'), (
  331. (['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}),
  332. (['test1,test2', 'test3'], {'test1', 'test2', 'test3'}),
  333. (['test1, test2', 'test3'], {'test1', 'test2', 'test3'}),
  334. (['test1; test2', 'test3'], {'test1', 'test2', 'test3'}),
  335. (['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'}),
  336. ([''], {''}),
  337. (['test'], {'test'}),
  338. ))
  339. def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
  340. results = RedditConnector.split_args_input(test_subreddit_entries)
  341. assert results == expected
  342. def test_read_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
  343. test_file = tmp_path / 'test.txt'
  344. test_file.write_text('aaaaaa\nbbbbbb')
  345. results = RedditConnector.read_id_files([str(test_file)])
  346. assert results == {'aaaaaa', 'bbbbbb'}
  347. @pytest.mark.online
  348. @pytest.mark.reddit
  349. @pytest.mark.parametrize('test_redditor_name', (
  350. 'nasa',
  351. 'crowdstrike',
  352. 'HannibalGoddamnit',
  353. ))
  354. def test_check_user_existence_good(
  355. test_redditor_name: str,
  356. reddit_instance: praw.Reddit,
  357. downloader_mock: MagicMock,
  358. ):
  359. downloader_mock.reddit_instance = reddit_instance
  360. RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
  361. @pytest.mark.online
  362. @pytest.mark.reddit
  363. @pytest.mark.parametrize('test_redditor_name', (
  364. 'lhnhfkuhwreolo',
  365. 'adlkfmnhglojh',
  366. ))
  367. def test_check_user_existence_nonexistent(
  368. test_redditor_name: str,
  369. reddit_instance: praw.Reddit,
  370. downloader_mock: MagicMock,
  371. ):
  372. downloader_mock.reddit_instance = reddit_instance
  373. with pytest.raises(BulkDownloaderException, match='Could not find'):
  374. RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
  375. @pytest.mark.online
  376. @pytest.mark.reddit
  377. @pytest.mark.parametrize('test_redditor_name', (
  378. 'Bree-Boo',
  379. ))
  380. def test_check_user_existence_banned(
  381. test_redditor_name: str,
  382. reddit_instance: praw.Reddit,
  383. downloader_mock: MagicMock,
  384. ):
  385. downloader_mock.reddit_instance = reddit_instance
  386. with pytest.raises(BulkDownloaderException, match='is banned'):
  387. RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
  388. @pytest.mark.online
  389. @pytest.mark.reddit
  390. @pytest.mark.parametrize(('test_subreddit_name', 'expected_message'), (
  391. ('donaldtrump', 'cannot be found'),
  392. ('submitters', 'private and cannot be scraped'),
  393. ('lhnhfkuhwreolo', 'does not exist')
  394. ))
  395. def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit):
  396. test_subreddit = reddit_instance.subreddit(test_subreddit_name)
  397. with pytest.raises(BulkDownloaderException, match=expected_message):
  398. RedditConnector.check_subreddit_status(test_subreddit)
  399. @pytest.mark.online
  400. @pytest.mark.reddit
  401. @pytest.mark.parametrize('test_subreddit_name', (
  402. 'Python',
  403. 'Mindustry',
  404. 'TrollXChromosomes',
  405. 'all',
  406. ))
  407. def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit):
  408. test_subreddit = reddit_instance.subreddit(test_subreddit_name)
  409. RedditConnector.check_subreddit_status(test_subreddit)