# mongodb.py (3.1 KB)
  1. import re
  2. from enum import Enum
  3. from functools import wraps
  4. from logging import getLogger, basicConfig, INFO
  5. from pymongo import MongoClient
  6. from pymongo.errors import ServerSelectionTimeoutError, PyMongoError
  7. logger = getLogger(__name__)
  8. basicConfig(format="%(message)s", level=INFO)
  9. def with_retry(tries):
  10. def outer_wrapper(f):
  11. @wraps(f)
  12. def inner_wrapper(*args, **kwargs):
  13. def _retry(t=tries):
  14. if t <= 0:
  15. logger.error("unable to write hit to database")
  16. # raise WriteError(f"unable to write to database")
  17. return
  18. try:
  19. f(*args, **kwargs)
  20. except PyMongoError:
  21. t -= 1
  22. _retry(t)
  23. return _retry()
  24. return inner_wrapper
  25. return outer_wrapper
  26. class StringEnum(Enum):
  27. def __str__(self):
  28. return str(self.value["text"])
  29. def __repr__(self):
  30. return str(self.value["key"])
  31. class Access(StringEnum):
  32. PUBLIC = {"key": "+", "text": "public"}
  33. PRIVATE = {"key": "-", "text": "private"}
  34. class Hit:
  35. def __init__(self, url: str, access: Access):
  36. self.url = url
  37. self.access = access
  38. def __iter__(self):
  39. yield from {"url": self.url, "access": str(self.access)}.items()
  40. def is_valid(self):
  41. return (
  42. re.match(r"^https?://.*\.amazonaws.com/.*$", self.url)
  43. and self.access in Access
  44. )
  45. class MongoDB:
  46. def __init__(
  47. self,
  48. host: str = "0.0.0.0",
  49. port: int = 27017,
  50. db_name: str = "s3recon",
  51. col_name: str = "hits",
  52. unique_indicies: tuple = ("url",),
  53. indicies: tuple = ("access",),
  54. timeout: int = 10,
  55. ):
  56. self.client = MongoClient(host, port, serverSelectionTimeoutMS=timeout)
  57. self.db_name = db_name
  58. self.col_name = col_name
  59. self.index(unique_indicies, unique=True)
  60. self.index(indicies)
  61. def __del__(self):
  62. self.client.close()
  63. def index(self, indicies=(), **kwargs):
  64. for i in indicies:
  65. self.client[self.db_name][self.col_name].ensure_index(i, **kwargs)
  66. @staticmethod
  67. def normalize(item):
  68. if isinstance(item, (list, set)):
  69. return list(map(dict, item))
  70. else:
  71. return dict(item)
  72. @with_retry(3)
  73. def insert_many(self, items):
  74. self.client[self.db_name][self.col_name].insert_many(self.normalize(items))
  75. @with_retry(3)
  76. def insert(self, item):
  77. self.client[self.db_name][self.col_name].insert(self.normalize(item))
  78. @with_retry(3)
  79. def update_many(self, filter, items):
  80. self.client[self.db_name][self.col_name].update_many(
  81. filter, self.normalize(items), upsert=True
  82. )
  83. @with_retry(3)
  84. def update(self, filter, item):
  85. self.client[self.db_name][self.col_name].update(
  86. filter, self.normalize(item), upsert=True
  87. )
  88. def is_connected(self):
  89. try:
  90. self.client.server_info()
  91. except ServerSelectionTimeoutError:
  92. return False
  93. return True