get_submissions.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import time
  2. import traceback
  3. from . import common
  4. from . import exceptions
  5. from . import pushshift
  6. from . import tsdb
  7. def _normalize_subreddit(subreddit):
  8. if subreddit is None:
  9. pass
  10. elif isinstance(subreddit, str):
  11. subreddit = common.r.subreddit(subreddit)
  12. elif not isinstance(subreddit, common.praw.models.Subreddit):
  13. raise TypeError(type(subreddit))
  14. return subreddit
  15. def _normalize_user(user):
  16. if user is None:
  17. pass
  18. elif isinstance(user, str):
  19. user = common.r.redditor(user)
  20. elif not isinstance(user, common.praw.models.Redditor):
  21. raise TypeError(type(user))
  22. return user
  23. def get_submissions(
  24. subreddit=None,
  25. username=None,
  26. lower=None,
  27. upper=None,
  28. do_supplement=True,
  29. ):
  30. '''
  31. Collect submissions across time.
  32. Please see the global DOCSTRING variable.
  33. '''
  34. if not common.is_xor(subreddit, username):
  35. raise exceptions.NotExclusive(['subreddit', 'username'])
  36. common.login()
  37. if subreddit:
  38. (database, subreddit) = tsdb.TSDB.for_subreddit(subreddit, fix_name=True)
  39. elif username:
  40. (database, username) = tsdb.TSDB.for_user(username, fix_name=True)
  41. cur = database.sql.cursor()
  42. subreddit = _normalize_subreddit(subreddit)
  43. user = _normalize_user(username)
  44. if lower == 'update':
  45. # Start from the latest submission
  46. cur.execute('SELECT created FROM submissions ORDER BY created DESC LIMIT 1')
  47. fetch = cur.fetchone()
  48. if fetch is not None:
  49. lower = fetch[0]
  50. else:
  51. lower = None
  52. if lower is None:
  53. lower = 0
  54. if username:
  55. submissions = pushshift.get_submissions_from_user(username, lower=lower, upper=upper)
  56. else:
  57. submissions = pushshift.get_submissions_from_subreddit(subreddit, lower=lower, upper=upper)
  58. if do_supplement:
  59. submissions = pushshift.supplement_reddit_data(submissions, chunk_size=100)
  60. submissions = common.generator_chunker(submissions, 200)
  61. form = '{lower} ({lower_unix}) - {upper} ({upper_unix}) +{gain}'
  62. for chunk in submissions:
  63. chunk.sort(key=lambda x: x.created_utc)
  64. step = database.insert(chunk)
  65. message = form.format(
  66. lower=common.human(chunk[0].created_utc),
  67. upper=common.human(chunk[-1].created_utc),
  68. lower_unix=int(chunk[0].created_utc),
  69. upper_unix=int(chunk[-1].created_utc),
  70. gain=step['new_submissions'],
  71. )
  72. print(message)
  73. cur.execute('SELECT COUNT(idint) FROM submissions')
  74. itemcount = cur.fetchone()[0]
  75. print('Ended with %d items in %s' % (itemcount, database.filepath.basename))
  76. def get_submissions_argparse(args):
  77. if args.lower == 'update':
  78. lower = 'update'
  79. else:
  80. lower = common.int_none(args.lower)
  81. return get_submissions(
  82. subreddit=args.subreddit,
  83. username=args.username,
  84. lower=lower,
  85. upper=common.int_none(args.upper),
  86. do_supplement=args.do_supplement,
  87. )