123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- #!/usr/bin/env python3
- # coding=utf-8
import configparser
import logging
import random
import re
import secrets
import socket
from pathlib import Path
from urllib.parse import parse_qsl, urlsplit

import praw
import requests

from bdfr.exceptions import BulkDownloaderException, RedditAuthenticationError
- logger = logging.getLogger(__name__)
class OAuth2Authenticator:
    """Obtain a Reddit OAuth2 refresh token via the browser-based code flow.

    The authenticator validates the requested scopes on construction, then
    (in :meth:`retrieve_new_token`) asks the user to authorise the application
    in a browser and listens on port 7634 for the redirect that carries the
    authorisation code.
    """

    def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str):
        """Validate ``wanted_scopes`` and store the client credentials.

        Args:
            wanted_scopes: Scope names to request during authorisation.
            client_id: Client id passed to ``praw.Reddit``.
            client_secret: Client secret passed to ``praw.Reddit``.

        Raises:
            BulkDownloaderException: If any requested scope is not known.
        """
        self._check_scopes(wanted_scopes)
        self.scopes = wanted_scopes
        self.client_id = client_id
        self.client_secret = client_secret

    @staticmethod
    def _check_scopes(wanted_scopes: set[str]):
        """Raise if any scope in ``wanted_scopes`` is absent from Reddit's scope list.

        The list is fetched from ``/api/v1/scopes.json``; ``'*'`` is always
        accepted in addition to the names returned by that endpoint.

        Raises:
            BulkDownloaderException: On the first unknown scope encountered.
        """
        response = requests.get(
            'https://www.reddit.com/api/v1/scopes.json',
            headers={'User-Agent': 'fetch-scopes test'},
            # Without a timeout an unresponsive server blocks construction forever.
            timeout=30,
        )
        # Only the scope names (the JSON object's keys) matter; a set gives O(1) lookups.
        known_scopes = {'*', *response.json().keys()}
        for scope in wanted_scopes:
            if scope not in known_scopes:
                raise BulkDownloaderException(f'Scope {scope} is not known to reddit')

    @staticmethod
    def split_scopes(scopes: str) -> set[str]:
        """Split a scope string on commas, colons and spaces.

        Empty fragments (produced by leading/trailing separators or an empty
        input) are discarded so they are never mistaken for a scope name.

        Args:
            scopes: Separator-delimited scope names, e.g. ``'identity, read'``.

        Returns:
            The set of non-empty scope names.
        """
        return {scope for scope in re.split(r'[,: ]+', scopes) if scope}

    def retrieve_new_token(self) -> str:
        """Run the interactive authorisation flow and return a refresh token.

        Logs the authorisation URL, blocks until a client connects to the
        local callback listener, validates the returned ``state`` and
        exchanges the authorisation code through ``reddit.auth.authorize``.

        Returns:
            The value returned by ``reddit.auth.authorize`` for the received code.

        Raises:
            RedditAuthenticationError: If the callback's state does not match
                (including a malformed callback with no state), if it reports
                an error, or if it carries no authorisation code.
        """
        reddit = praw.Reddit(
            redirect_uri='http://localhost:7634',
            user_agent='obtain_refresh_token for BDFR',
            client_id=self.client_id,
            client_secret=self.client_secret,
        )
        # The state value guards the callback against forgery, so it must be
        # unpredictable; a small pseudo-random integer is not.
        state = secrets.token_urlsafe(32)
        url = reddit.auth.url(self.scopes, state, 'permanent')
        logger.warning('Authentication action required before the program can proceed')
        logger.warning(f'Authenticate at {url}')

        client = self.receive_connection()
        try:
            # Only the request line is needed, and it sits at the very start
            # of the request, so a single bounded read is used.
            request = client.recv(1024).decode('utf-8', errors='replace')
        except OSError:
            client.close()
            raise
        params = self._parse_callback_params(request)

        received_state = params.get('state')
        if received_state != state:
            self.send_message(client)
            raise RedditAuthenticationError(
                f'State mismatch in OAuth2. Expected: {state} Received: {received_state}')
        if 'error' in params:
            self.send_message(client)
            raise RedditAuthenticationError(f'Error in OAuth2: {params["error"]}')
        if 'code' not in params:
            self.send_message(client)
            raise RedditAuthenticationError('No authorisation code received in OAuth2 callback')

        self.send_message(client, "<script>alert('You can go back to terminal window now.')</script>")
        return reddit.auth.authorize(params['code'])

    @staticmethod
    def _parse_callback_params(request: str) -> dict[str, str]:
        """Return the URL-decoded query parameters of a raw HTTP request.

        The request target is the second space-separated token of the request
        line. Malformed input yields an empty dict rather than an exception so
        the caller can answer the browser and report one consistent error.
        """
        request_line = request.split('\r\n', 1)[0]
        parts = request_line.split(' ')
        if len(parts) < 2:
            return {}
        try:
            query = urlsplit(parts[1]).query
        except ValueError:
            return {}
        return dict(parse_qsl(query))

    @staticmethod
    def receive_connection() -> socket.socket:
        """Accept one TCP connection on port 7634 and return the client socket.

        The listening socket is closed before returning, even if ``accept``
        is interrupted. The caller owns (and must close) the returned socket.
        """
        # NOTE(review): the listener binds every interface although the
        # redirect URI is localhost; confirm whether 127.0.0.1 would suffice.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
            server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server.bind(('0.0.0.0', 7634))
            logger.log(9, 'Server listening on 0.0.0.0:7634')
            server.listen(1)
            client = server.accept()[0]
        logger.log(9, 'Server closed')
        return client

    @staticmethod
    def send_message(client: socket.socket, message: str = ''):
        """Send a minimal ``200 OK`` response with ``message`` as body, then close.

        ``sendall`` is used so the whole response is written, and the socket
        is closed even when sending fails.
        """
        try:
            client.sendall(f'HTTP/1.1 200 OK\r\n\r\n{message}'.encode('utf-8'))
        finally:
            client.close()
class OAuth2TokenManager(praw.reddit.BaseTokenManager):
    """Token manager that keeps the OAuth2 refresh token in a config file.

    The token is read from and written to the ``user_token`` option of the
    ``DEFAULT`` section of the supplied ``ConfigParser``.
    """

    def __init__(self, config: configparser.ConfigParser, config_location: Path):
        """Remember the parsed configuration and the path it is saved to."""
        super().__init__()
        self.config_location = config_location
        self.config = config

    def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer):
        """Supply the stored refresh token when the authorizer has none.

        Raises:
            RedditAuthenticationError: If the authorizer has no token and the
                configuration has no ``user_token`` option either.
        """
        if authorizer.refresh_token is not None:
            # The authorizer already carries a token; leave it untouched.
            return
        if not self.config.has_option('DEFAULT', 'user_token'):
            raise RedditAuthenticationError('No auth token loaded in configuration')
        authorizer.refresh_token = self.config.get('DEFAULT', 'user_token')
        logger.log(9, 'Loaded OAuth2 token for authoriser')

    def post_refresh_callback(self, authorizer: praw.reddit.Authorizer):
        """Store the authorizer's refresh token and rewrite the config file."""
        self.config.set('DEFAULT', 'user_token', authorizer.refresh_token)
        with open(self.config_location, 'w') as config_file:
            self.config.write(config_file, space_around_delimiters=True)
        logger.log(9, f'Written OAuth2 token from authoriser to {self.config_location}')
|