123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469 |
- #!/usr/bin/env python3
- # coding=utf-8
- from datetime import datetime, timedelta
- from pathlib import Path
- from typing import Iterator
- from unittest.mock import MagicMock
- import praw
- import praw.models
- import pytest
- from bdfr.configuration import Configuration
- from bdfr.connector import RedditConnector, RedditTypes
- from bdfr.download_filter import DownloadFilter
- from bdfr.exceptions import BulkDownloaderException
- from bdfr.file_name_formatter import FileNameFormatter
- from bdfr.site_authenticator import SiteAuthenticator
@pytest.fixture()
def args() -> Configuration:
    """Provide a fresh ``Configuration`` whose ``time_format`` is set to ``'ISO'``."""
    configuration = Configuration()
    configuration.time_format = 'ISO'
    return configuration
@pytest.fixture()
def downloader_mock(args: Configuration):
    """Build a ``MagicMock`` stand-in for a connector instance.

    The mock carries the ``args`` fixture, an empty ``master_hash_list`` and a
    few real ``RedditConnector`` callables so that tests can pass it as ``self``
    to unbound ``RedditConnector`` methods.
    """
    mock = MagicMock()
    mock.args = args
    mock.master_hash_list = {}

    # Real helpers are attached in place of the auto-generated mock attributes.
    mock.sanitise_subreddit_name = RedditConnector.sanitise_subreddit_name
    mock.split_args_input = RedditConnector.split_args_input

    def _filtered_listing(source):
        # Delegate to the real implementation, using the mock as ``self``.
        return RedditConnector.create_filtered_listing_generator(mock, source)

    mock.create_filtered_listing_generator = _filtered_listing
    return mock
def assert_all_results_are_submissions(result_limit: int, results: list[Iterator]) -> list:
    """Flatten ``results`` and assert every item is a ``praw.models.Submission``.

    Also asserts that no item is a ``MagicMock`` and, when ``result_limit`` is
    not ``None``, that the flattened length equals ``result_limit``.

    Returns the flattened list of items.
    """
    flattened = []
    for listing in results:
        flattened.extend(listing)

    assert all(isinstance(item, praw.models.Submission) for item in flattened)
    assert not any(isinstance(item, MagicMock) for item in flattened)
    if result_limit is not None:
        assert len(flattened) == result_limit
    return flattened
- def assert_all_results_are_submissions_or_comments(result_limit: int, results: list[Iterator]) -> list:
- results = [sub for res in results for sub in res]
- assert all([isinstance(res, praw.models.Submission) or isinstance(res, praw.models.Comment) for res in results])
- assert not any([isinstance(m, MagicMock) for m in results])
- if result_limit is not None:
- assert len(results) == result_limit
- return results
def test_determine_directories(tmp_path: Path, downloader_mock: MagicMock):
    """The directory named in ``args.directory`` must exist after ``determine_directories`` runs."""
    target_directory = tmp_path / 'test'
    downloader_mock.args.directory = target_directory
    downloader_mock.config_directories.user_config_dir = tmp_path

    RedditConnector.determine_directories(downloader_mock)

    assert Path(target_directory).exists()
@pytest.mark.parametrize(('skip_extensions', 'skip_domains'), [
    ([], []),
    (['.test'], ['test.com']),
])
def test_create_download_filter(skip_extensions: list[str], skip_domains: list[str], downloader_mock: MagicMock):
    """``create_download_filter`` must carry ``args.skip`` and ``args.skip_domain`` into the filter."""
    downloader_mock.args.skip_domain = skip_domains
    downloader_mock.args.skip = skip_extensions

    download_filter = RedditConnector.create_download_filter(downloader_mock)

    assert isinstance(download_filter, DownloadFilter)
    assert download_filter.excluded_extensions == skip_extensions
    assert download_filter.excluded_domains == skip_domains
@pytest.mark.parametrize(('test_time', 'expected'), [
    ('all', 'all'),
    ('hour', 'hour'),
    ('day', 'day'),
    ('week', 'week'),
    ('random', 'all'),
    ('', 'all'),
])
def test_create_time_filter(test_time: str, expected: str, downloader_mock: MagicMock):
    """``create_time_filter`` must map ``args.time`` to the expected ``TimeType`` member.

    The ``'random'`` and empty-string cases expect the ``'all'`` member.
    """
    downloader_mock.args.time = test_time

    time_filter = RedditConnector.create_time_filter(downloader_mock)

    assert isinstance(time_filter, RedditTypes.TimeType)
    member_name = time_filter.name.lower()
    assert member_name == expected
@pytest.mark.parametrize(('test_sort', 'expected'), [
    ('', 'hot'),
    ('hot', 'hot'),
    ('controversial', 'controversial'),
    ('new', 'new'),
])
def test_create_sort_filter(test_sort: str, expected: str, downloader_mock: MagicMock):
    """``create_sort_filter`` must map ``args.sort`` to the expected ``SortType`` member.

    The empty-string case expects the ``'hot'`` member.
    """
    downloader_mock.args.sort = test_sort

    sort_filter = RedditConnector.create_sort_filter(downloader_mock)

    assert isinstance(sort_filter, RedditTypes.SortType)
    member_name = sort_filter.name.lower()
    assert member_name == expected
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), [
    ('{POSTID}', '{SUBREDDIT}'),
    ('{REDDITOR}_{TITLE}_{POSTID}', '{SUBREDDIT}'),
    ('{POSTID}', 'test'),
    ('{POSTID}', ''),
    ('{POSTID}', '{SUBREDDIT}/{REDDITOR}'),
])
def test_create_file_name_formatter(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
    """Valid schemes must yield a ``FileNameFormatter`` holding the file scheme and the split folder scheme."""
    downloader_mock.args.folder_scheme = test_folder_scheme
    downloader_mock.args.file_scheme = test_file_scheme

    formatter = RedditConnector.create_file_name_formatter(downloader_mock)

    assert isinstance(formatter, FileNameFormatter)
    assert formatter.file_format_string == test_file_scheme
    # The folder scheme is expected as a list of its '/'-separated parts.
    expected_directory_parts = test_folder_scheme.split('/')
    assert formatter.directory_format_string == expected_directory_parts
@pytest.mark.parametrize(('test_file_scheme', 'test_folder_scheme'), [
    ('', ''),
    ('', '{SUBREDDIT}'),
    ('test', '{SUBREDDIT}'),
])
def test_create_file_name_formatter_bad(test_file_scheme: str, test_folder_scheme: str, downloader_mock: MagicMock):
    """These file schemes must make ``create_file_name_formatter`` raise ``BulkDownloaderException``."""
    downloader_mock.args.folder_scheme = test_folder_scheme
    downloader_mock.args.file_scheme = test_file_scheme

    with pytest.raises(BulkDownloaderException):
        RedditConnector.create_file_name_formatter(downloader_mock)
def test_create_authenticator(downloader_mock: MagicMock):
    """``create_authenticator`` must return a ``SiteAuthenticator`` instance."""
    authenticator = RedditConnector.create_authenticator(downloader_mock)
    assert isinstance(authenticator, SiteAuthenticator)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_submission_ids', [
    ('lvpf4l',),
    ('lvpf4l', 'lvqnsn'),
    ('lvpf4l', 'lvqnsn', 'lvl9kd'),
])
def test_get_submissions_from_link(
        test_submission_ids: list[str],
        reddit_instance: praw.Reddit,
        downloader_mock: MagicMock,
):
    """Every ID in ``args.link`` must come back as a ``Submission`` in the first result group."""
    downloader_mock.reddit_instance = reddit_instance
    downloader_mock.args.link = test_submission_ids

    listings = RedditConnector.get_submissions_from_link(downloader_mock)

    fetched = [submission for listing in listings for submission in listing]
    assert all(isinstance(submission, praw.models.Submission) for submission in fetched)
    assert len(listings[0]) == len(test_submission_ids)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'limit', 'sort_type', 'time_filter', 'max_expected_len'), (
    (('Futurology',), 10, 'hot', 'all', 10),
    (('Futurology', 'Mindustry, Python'), 10, 'hot', 'all', 30),
    (('Futurology',), 20, 'hot', 'all', 20),
    (('Futurology', 'Python'), 10, 'hot', 'all', 20),
    (('Futurology',), 100, 'hot', 'all', 100),
    (('Futurology',), 0, 'hot', 'all', 0),
    (('Futurology',), 10, 'top', 'all', 10),
    (('Futurology',), 10, 'top', 'week', 10),
    (('Futurology',), 10, 'hot', 'week', 10),
))
def test_get_subreddit_normal(
        test_subreddits: list[str],
        limit: int,
        sort_type: str,
        time_filter: str,
        max_expected_len: int,
        downloader_mock: MagicMock,
        reddit_instance: praw.Reddit,
):
    """Fetch listings for the given subreddits and validate what comes back.

    Every result must be a real ``Submission`` from one of the requested
    subreddits, and the total must not exceed ``max_expected_len``.
    """
    downloader_mock.args.limit = limit
    downloader_mock.args.sort = sort_type
    # The parametrized time filter was previously accepted but never applied;
    # it must be on ``args.time`` before ``create_time_filter`` is called.
    downloader_mock.args.time = time_filter
    downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock)
    downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock)
    downloader_mock.determine_sort_function.return_value = RedditConnector.determine_sort_function(downloader_mock)
    downloader_mock.args.subreddit = test_subreddits
    downloader_mock.reddit_instance = reddit_instance

    listings = RedditConnector.get_subreddits(downloader_mock)

    # Entries such as 'Mindustry, Python' are split into individual names before comparison.
    expected_subreddits = downloader_mock.split_args_input(test_subreddits)
    submissions = [submission for listing in listings for submission in listing]
    assert all(isinstance(submission, praw.models.Submission) for submission in submissions)
    assert all(submission.subreddit.display_name in expected_subreddits for submission in submissions)
    assert len(submissions) <= max_expected_len
    assert not any(isinstance(submission, MagicMock) for submission in submissions)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_time', 'test_delta'), [
    ('hour', timedelta(hours=1)),
    ('day', timedelta(days=1)),
    ('week', timedelta(days=7)),
    ('month', timedelta(days=31)),
    ('year', timedelta(days=365)),
])
def test_get_subreddit_time_verification(
        test_time: str,
        test_delta: timedelta,
        downloader_mock: MagicMock,
        reddit_instance: praw.Reddit,
):
    """Top submissions of ``r/all`` for a time window must all be newer than ``test_delta``."""
    downloader_mock.args.time = test_time
    downloader_mock.args.sort = 'top'
    downloader_mock.args.limit = 10
    downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock)
    downloader_mock.sort_filter = RedditConnector.create_sort_filter(downloader_mock)
    downloader_mock.determine_sort_function.return_value = RedditConnector.determine_sort_function(downloader_mock)
    downloader_mock.reddit_instance = reddit_instance
    downloader_mock.args.subreddit = ['all']

    listings = RedditConnector.get_subreddits(downloader_mock)

    submissions = []
    for listing in listings:
        submissions.extend(listing)
    assert all(isinstance(submission, praw.models.Submission) for submission in submissions)

    # One reference time is taken up front and compared against every submission.
    reference_time = datetime.now()
    for submission in submissions:
        submission_age = reference_time - datetime.fromtimestamp(submission.created_utc)
        assert submission_age < test_delta
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddits', 'search_term', 'limit', 'time_filter', 'max_expected_len'), [
    (('Python',), 'scraper', 10, 'all', 10),
    (('Python',), '', 10, 'all', 0),
    (('Python',), 'djsdsgewef', 10, 'all', 0),
    (('Python',), 'scraper', 10, 'year', 10),
])
def test_get_subreddit_search(
        test_subreddits: list[str],
        search_term: str,
        time_filter: str,
        limit: int,
        max_expected_len: int,
        downloader_mock: MagicMock,
        reddit_instance: praw.Reddit,
):
    """Search the given subreddits and validate the returned submissions.

    Results must be real ``Submission`` objects from the requested subreddits,
    number at most ``max_expected_len``, and be non-empty whenever
    ``max_expected_len`` is non-zero.
    """
    # NOTE(review): sibling tests configure ``determine_sort_function`` (no leading
    # underscore); confirm whether this differently named attribute is intentional.
    downloader_mock._determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.args.search = search_term
    downloader_mock.args.limit = limit
    downloader_mock.args.subreddit = test_subreddits
    downloader_mock.reddit_instance = reddit_instance
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    downloader_mock.args.time = time_filter
    downloader_mock.time_filter = RedditConnector.create_time_filter(downloader_mock)

    listings = RedditConnector.get_subreddits(downloader_mock)

    submissions = []
    for listing in listings:
        submissions.extend(listing)
    assert all(isinstance(submission, praw.models.Submission) for submission in submissions)
    assert all(submission.subreddit.display_name in test_subreddits for submission in submissions)
    assert len(submissions) <= max_expected_len
    if max_expected_len != 0:
        assert len(submissions) > 0
    assert not any(isinstance(submission, MagicMock) for submission in submissions)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'test_multireddits', 'limit'), [
    ('helen_darten', ('cuteanimalpics',), 10),
    ('korfor', ('chess',), 100),
])
# Good sources at https://www.reddit.com/r/multihub/
def test_get_multireddits_public(
        test_user: str,
        test_multireddits: list[str],
        limit: int,
        reddit_instance: praw.Reddit,
        downloader_mock: MagicMock,
):
    """Fetching a user's public multireddit must yield exactly ``limit`` real submissions."""
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.args.user = [test_user]
    downloader_mock.args.multireddit = test_multireddits
    downloader_mock.args.limit = limit
    downloader_mock.reddit_instance = reddit_instance
    multireddit_source = reddit_instance.multireddit(test_user, test_multireddits[0])
    downloader_mock.create_filtered_listing_generator.return_value = (
        RedditConnector.create_filtered_listing_generator(downloader_mock, multireddit_source)
    )

    listings = RedditConnector.get_multireddits(downloader_mock)

    submissions = []
    for listing in listings:
        submissions.extend(listing)
    assert all(isinstance(submission, praw.models.Submission) for submission in submissions)
    assert len(submissions) == limit
    assert not any(isinstance(submission, MagicMock) for submission in submissions)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_user', 'limit'), (
    ('danigirl3694', 10),
    ('danigirl3694', 50),
    ('CapitanHam', None),
))
def test_get_user_submissions(test_user: str, limit: int, downloader_mock: MagicMock, reddit_instance: praw.Reddit):
    """Fetch a user's submitted posts and check that each one is authored by that user.

    ``limit`` may be ``None``, in which case the result count is not checked.
    """
    downloader_mock.args.limit = limit
    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    downloader_mock.args.submitted = True
    downloader_mock.args.user = [test_user]
    downloader_mock.authenticated = False
    downloader_mock.reddit_instance = reddit_instance
    downloader_mock.create_filtered_listing_generator.return_value = \
        RedditConnector.create_filtered_listing_generator(
            downloader_mock,
            reddit_instance.redditor(test_user).submissions,
        )

    results = RedditConnector.get_user_data(downloader_mock)

    # The helper flattens the listings and already asserts that every item is a
    # Submission, that none is a MagicMock, and (when limit is set) the count.
    results = assert_all_results_are_submissions(limit, results)
    assert all(res.author.name == test_user for res in results)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
@pytest.mark.parametrize('test_flag', [
    'upvoted',
    'saved',
])
def test_get_user_authenticated_lists(
        test_flag: str,
        downloader_mock: MagicMock,
        authenticated_reddit_instance: praw.Reddit,
):
    """With an authenticated instance, the flagged list must yield 10 submissions or comments."""
    # Switch on the option named by ``test_flag`` directly in the args namespace dict.
    vars(downloader_mock.args)[test_flag] = True
    downloader_mock.reddit_instance = authenticated_reddit_instance
    downloader_mock.args.limit = 10
    downloader_mock.sort_filter = RedditTypes.SortType.HOT
    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
    resolved_name = RedditConnector.resolve_user_name(downloader_mock, 'me')
    downloader_mock.args.user = [resolved_name]

    listings = RedditConnector.get_user_data(downloader_mock)

    assert_all_results_are_submissions_or_comments(10, listings)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.authenticated
def test_get_subscribed_subreddits(downloader_mock: MagicMock, authenticated_reddit_instance: praw.Reddit):
    """With ``args.subscribed`` set, ``get_subreddits`` must return a non-empty list of listing generators."""
    downloader_mock.reddit_instance = authenticated_reddit_instance
    downloader_mock.args.limit = 10
    downloader_mock.args.authenticate = True
    downloader_mock.args.subscribed = True
    # Assigned once; the original repeated this identical statement twice.
    downloader_mock.determine_sort_function.return_value = praw.models.Subreddit.hot
    downloader_mock.sort_filter = RedditTypes.SortType.HOT

    results = RedditConnector.get_subreddits(downloader_mock)

    assert all(isinstance(s, praw.models.ListingGenerator) for s in results)
    assert len(results) > 0
@pytest.mark.parametrize(('test_name', 'expected'), [
    ('Mindustry', 'Mindustry'),
    ('Futurology', 'Futurology'),
    ('r/Mindustry', 'Mindustry'),
    ('TrollXChromosomes', 'TrollXChromosomes'),
    ('r/TrollXChromosomes', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/TrollXChromosomes/', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/TrollXChromosomes', 'TrollXChromosomes'),
    ('https://www.reddit.com/r/Futurology/', 'Futurology'),
    ('https://www.reddit.com/r/Futurology', 'Futurology'),
])
def test_sanitise_subreddit_name(test_name: str, expected: str):
    """Bare names, ``r/`` prefixed names and full URLs must all reduce to the bare subreddit name."""
    sanitised = RedditConnector.sanitise_subreddit_name(test_name)
    assert sanitised == expected
@pytest.mark.parametrize(('test_subreddit_entries', 'expected'), [
    (['test1', 'test2', 'test3'], {'test1', 'test2', 'test3'}),
    (['test1,test2', 'test3'], {'test1', 'test2', 'test3'}),
    (['test1, test2', 'test3'], {'test1', 'test2', 'test3'}),
    (['test1; test2', 'test3'], {'test1', 'test2', 'test3'}),
    (['test1, test2', 'test1,test2,test3', 'test4'], {'test1', 'test2', 'test3', 'test4'}),
    ([''], {''}),
    (['test'], {'test'}),
])
def test_split_subreddit_entries(test_subreddit_entries: list[str], expected: set[str]):
    """``split_args_input`` must split comma/semicolon separated entries into a de-duplicated set."""
    split_entries = RedditConnector.split_args_input(test_subreddit_entries)
    assert split_entries == expected
def test_read_submission_ids_from_file(downloader_mock: MagicMock, tmp_path: Path):
    """``read_id_files`` must return the set of IDs found one per line in the given file."""
    id_file = tmp_path / 'test.txt'
    id_file.write_text('aaaaaa\nbbbbbb')

    file_paths = [str(id_file)]
    submission_ids = RedditConnector.read_id_files(file_paths)

    assert submission_ids == {'aaaaaa', 'bbbbbb'}
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', [
    'nasa',
    'crowdstrike',
    'HannibalGoddamnit',
])
def test_check_user_existence_good(
        test_redditor_name: str,
        reddit_instance: praw.Reddit,
        downloader_mock: MagicMock,
):
    """``check_user_existence`` must complete without raising for these redditor names."""
    downloader_mock.reddit_instance = reddit_instance
    # Passing means no exception was raised; there is nothing further to assert.
    RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', [
    'lhnhfkuhwreolo',
    'adlkfmnhglojh',
])
def test_check_user_existence_nonexistent(
        test_redditor_name: str,
        reddit_instance: praw.Reddit,
        downloader_mock: MagicMock,
):
    """These names must raise ``BulkDownloaderException`` with a 'Could not find' message."""
    downloader_mock.reddit_instance = reddit_instance
    expected_error = pytest.raises(BulkDownloaderException, match='Could not find')
    with expected_error:
        RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_redditor_name', [
    'Bree-Boo',
])
def test_check_user_existence_banned(
        test_redditor_name: str,
        reddit_instance: praw.Reddit,
        downloader_mock: MagicMock,
):
    """This name must raise ``BulkDownloaderException`` with an 'is banned' message."""
    downloader_mock.reddit_instance = reddit_instance
    expected_error = pytest.raises(BulkDownloaderException, match='is banned')
    with expected_error:
        RedditConnector.check_user_existence(downloader_mock, test_redditor_name)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize(('test_subreddit_name', 'expected_message'), [
    ('donaldtrump', 'cannot be found'),
    ('submitters', 'private and cannot be scraped'),
    ('lhnhfkuhwreolo', 'does not exist'),
])
def test_check_subreddit_status_bad(test_subreddit_name: str, expected_message: str, reddit_instance: praw.Reddit):
    """Each subreddit must raise ``BulkDownloaderException`` with its paired message."""
    subreddit = reddit_instance.subreddit(test_subreddit_name)
    expected_error = pytest.raises(BulkDownloaderException, match=expected_message)
    with expected_error:
        RedditConnector.check_subreddit_status(subreddit)
@pytest.mark.online
@pytest.mark.reddit
@pytest.mark.parametrize('test_subreddit_name', [
    'Python',
    'Mindustry',
    'TrollXChromosomes',
    'all',
])
def test_check_subreddit_status_good(test_subreddit_name: str, reddit_instance: praw.Reddit):
    """``check_subreddit_status`` must complete without raising for these subreddits."""
    subreddit = reddit_instance.subreddit(test_subreddit_name)
    # Passing means no exception was raised; there is nothing further to assert.
    RedditConnector.check_subreddit_status(subreddit)
|