oauth2.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. #!/usr/bin/env python3
  2. # coding=utf-8
import configparser
import logging
import random
import re
import secrets
import socket
import urllib.parse
from pathlib import Path

import praw
import requests

from bdfr.exceptions import BulkDownloaderException, RedditAuthenticationError
# Module-level logger used by both OAuth2 classes in this file.
logger = logging.getLogger(__name__)
  13. class OAuth2Authenticator:
  14. def __init__(self, wanted_scopes: set[str], client_id: str, client_secret: str):
  15. self._check_scopes(wanted_scopes)
  16. self.scopes = wanted_scopes
  17. self.client_id = client_id
  18. self.client_secret = client_secret
  19. @staticmethod
  20. def _check_scopes(wanted_scopes: set[str]):
  21. response = requests.get('https://www.reddit.com/api/v1/scopes.json',
  22. headers={'User-Agent': 'fetch-scopes test'})
  23. known_scopes = [scope for scope, data in response.json().items()]
  24. known_scopes.append('*')
  25. for scope in wanted_scopes:
  26. if scope not in known_scopes:
  27. raise BulkDownloaderException(f'Scope {scope} is not known to reddit')
  28. @staticmethod
  29. def split_scopes(scopes: str) -> set[str]:
  30. scopes = re.split(r'[,: ]+', scopes)
  31. return set(scopes)
  32. def retrieve_new_token(self) -> str:
  33. reddit = praw.Reddit(
  34. redirect_uri='http://localhost:7634',
  35. user_agent='obtain_refresh_token for BDFR',
  36. client_id=self.client_id,
  37. client_secret=self.client_secret)
  38. state = str(random.randint(0, 65000))
  39. url = reddit.auth.url(self.scopes, state, 'permanent')
  40. logger.warning('Authentication action required before the program can proceed')
  41. logger.warning(f'Authenticate at {url}')
  42. client = self.receive_connection()
  43. data = client.recv(1024).decode('utf-8')
  44. param_tokens = data.split(' ', 2)[1].split('?', 1)[1].split('&')
  45. params = {key: value for (key, value) in [token.split('=') for token in param_tokens]}
  46. if state != params['state']:
  47. self.send_message(client)
  48. raise RedditAuthenticationError(f'State mismatch in OAuth2. Expected: {state} Received: {params["state"]}')
  49. elif 'error' in params:
  50. self.send_message(client)
  51. raise RedditAuthenticationError(f'Error in OAuth2: {params["error"]}')
  52. self.send_message(client, "<script>alert('You can go back to terminal window now.')</script>")
  53. refresh_token = reddit.auth.authorize(params["code"])
  54. return refresh_token
  55. @staticmethod
  56. def receive_connection() -> socket.socket:
  57. server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  58. server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  59. server.bind(('0.0.0.0', 7634))
  60. logger.log(9, 'Server listening on 0.0.0.0:7634')
  61. server.listen(1)
  62. client = server.accept()[0]
  63. server.close()
  64. logger.log(9, 'Server closed')
  65. return client
  66. @staticmethod
  67. def send_message(client: socket.socket, message: str = ''):
  68. client.send(f'HTTP/1.1 200 OK\r\n\r\n{message}'.encode('utf-8'))
  69. client.close()
  70. class OAuth2TokenManager(praw.reddit.BaseTokenManager):
  71. def __init__(self, config: configparser.ConfigParser, config_location: Path):
  72. super(OAuth2TokenManager, self).__init__()
  73. self.config = config
  74. self.config_location = config_location
  75. def pre_refresh_callback(self, authorizer: praw.reddit.Authorizer):
  76. if authorizer.refresh_token is None:
  77. if self.config.has_option('DEFAULT', 'user_token'):
  78. authorizer.refresh_token = self.config.get('DEFAULT', 'user_token')
  79. logger.log(9, 'Loaded OAuth2 token for authoriser')
  80. else:
  81. raise RedditAuthenticationError('No auth token loaded in configuration')
  82. def post_refresh_callback(self, authorizer: praw.reddit.Authorizer):
  83. self.config.set('DEFAULT', 'user_token', authorizer.refresh_token)
  84. with open(self.config_location, 'w') as file:
  85. self.config.write(file, True)
  86. logger.log(9, f'Written OAuth2 token from authoriser to {self.config_location}')