capture_http.py 8.0 KB


  1. import threading
  2. from io import BytesIO
  3. from six.moves import http_client as httplib
  4. from contextlib import contextmanager
  5. from array import array
  6. from warcio.utils import to_native_str, BUFF_SIZE, open
  7. from warcio.warcwriter import WARCWriter, BufferWARCWriter
  8. from tempfile import SpooledTemporaryFile
  9. # ============================================================================
# Keep a reference to the unpatched HTTPConnection class. Further down, this
# module rebinds httplib.HTTPConnection to RecordingHTTPConnection; the
# recording subclass then delegates __init__, send, _tunnel and putrequest to
# this saved original.
orig_connection = httplib.HTTPConnection
  11. # ============================================================================
  12. class RecordingStream(object):
  13. def __init__(self, fp, recorder):
  14. self.fp = fp
  15. self.recorder = recorder
  16. self.recorder.set_remote_ip(self._get_remote_ip())
  17. def _get_remote_ip(self):
  18. try:
  19. fp = self.fp
  20. # for python 3, need to get 'raw' fp
  21. if hasattr(fp, 'raw'): #pragma: no cover
  22. fp = fp.raw
  23. socket = fp._sock
  24. # wrapped ssl socket
  25. if hasattr(socket, 'socket'):
  26. socket = socket.socket
  27. return socket.getpeername()[0]
  28. except Exception: #pragma: no cover
  29. return None
  30. # Used in PY2 Only
  31. def read(self, amt=None): #pragma: no cover
  32. buff = self.fp.read(amt)
  33. self.recorder.write_response(buff)
  34. return buff
  35. # Used in PY3 Only
  36. def readinto(self, buff): #pragma: no cover
  37. res = self.fp.readinto(buff)
  38. self.recorder.write_response(buff)
  39. return res
  40. def readline(self, maxlen=-1):
  41. line = self.fp.readline(maxlen)
  42. self.recorder.write_response(line)
  43. return line
  44. def close(self):
  45. self.recorder.done()
  46. if self.fp:
  47. return self.fp.close()
  48. def flush(self):
  49. return self.fp.flush()
  50. # ============================================================================
  51. class RecordingHTTPResponse(httplib.HTTPResponse):
  52. def __init__(self, recorder, *args, **kwargs):
  53. httplib.HTTPResponse.__init__(self, *args, **kwargs)
  54. self.fp = RecordingStream(self.fp, recorder)
  55. # ============================================================================
  56. class RecordingHTTPConnection(httplib.HTTPConnection):
  57. local = threading.local()
  58. def __init__(self, *args, **kwargs):
  59. orig_connection.__init__(self, *args, **kwargs)
  60. if hasattr(self.local, 'recorder'):
  61. self.recorder = self.local.recorder
  62. else:
  63. self.recorder = None
  64. def make_recording_response(*args, **kwargs):
  65. return RecordingHTTPResponse(self.recorder, *args, **kwargs)
  66. if self.recorder:
  67. self.response_class = make_recording_response
  68. def send(self, data):
  69. if not self.recorder:
  70. orig_connection.send(self, data)
  71. return
  72. def send_request(buff):
  73. self.recorder.extract_url(buff, self.host, self.port, self.default_port)
  74. orig_connection.send(self, buff)
  75. self.recorder.write_request(buff)
  76. # if sending request body as stream
  77. # (supported via httplib but seems unused via higher-level apis)
  78. if hasattr(data, 'read') and not isinstance(data, array): #pragma: no cover
  79. while True:
  80. buff = data.read(BUFF_SIZE)
  81. if not buff:
  82. break
  83. send_request(buff)
  84. else:
  85. send_request(data)
  86. def _tunnel(self, *args, **kwargs):
  87. if self.recorder:
  88. self.recorder.start_tunnel()
  89. return orig_connection._tunnel(self, *args, **kwargs)
  90. def putrequest(self, *args, **kwargs):
  91. if self.recorder:
  92. self.recorder.start()
  93. return orig_connection.putrequest(self, *args, **kwargs)
  94. # ============================================================================
  95. class RequestRecorder(object):
  96. def __init__(self, writer, filter_func=None, record_ip=True):
  97. self.writer = writer
  98. self.filter_func = filter_func
  99. self.request_out = None
  100. self.response_out = None
  101. self.url = None
  102. self.connect_host = self.connect_port = None
  103. self.started_req = False
  104. self.first_line_read = False
  105. self.lock = threading.Lock()
  106. self.warc_headers = {}
  107. self.record_ip = record_ip
  108. def start_tunnel(self):
  109. self.connect_host = self.connect_port = None
  110. self.started_req = False
  111. self.first_line_read = False
  112. def start(self):
  113. self.request_out = self._create_buffer()
  114. self.response_out = self._create_buffer()
  115. self.url = None
  116. self.started_req = True
  117. self.first_line_read = False
  118. def _create_buffer(self):
  119. return SpooledTemporaryFile(BUFF_SIZE)
  120. def set_remote_ip(self, remote_ip):
  121. if self.record_ip and remote_ip: #pragma: no cover
  122. self.warc_headers['WARC-IP-Address'] = remote_ip
  123. def write_request(self, buff):
  124. if self.started_req:
  125. self.request_out.write(buff)
  126. def write_response(self, buff):
  127. if self.started_req:
  128. self.response_out.write(buff)
  129. def _create_record(self, out, record_type):
  130. length = out.tell()
  131. out.seek(0)
  132. return self.writer.create_warc_record(
  133. warc_headers_dict=self.warc_headers,
  134. uri=self.url,
  135. record_type=record_type,
  136. payload=out,
  137. length=length)
  138. def done(self):
  139. if not self.started_req:
  140. return
  141. try:
  142. request = self._create_record(self.request_out, 'request')
  143. response = self._create_record(self.response_out, 'response')
  144. if self.filter_func:
  145. request, response = self.filter_func(request, response, self)
  146. if not request or not response:
  147. return
  148. with self.lock:
  149. self.writer.write_request_response_pair(request, response)
  150. finally:
  151. self.request_out.close()
  152. self.response_out.close()
  153. def extract_url(self, data, host, port, default_port):
  154. if self.first_line_read:
  155. return
  156. self.first_line_read = True
  157. buff = BytesIO(data)
  158. line = to_native_str(buff.readline(), 'latin-1')
  159. parts = line.split(' ', 2)
  160. verb = parts[0]
  161. path = parts[1]
  162. if verb == "CONNECT":
  163. parts = path.split(":", 1)
  164. self.connect_host = parts[0]
  165. self.connect_port = int(parts[1]) if len(parts) > 1 else default_port
  166. self.warc_headers['WARC-Proxy-Host'] = "https://{0}:{1}".format(host, port)
  167. return
  168. if self.connect_host:
  169. host = self.connect_host
  170. if self.connect_port:
  171. port = self.connect_port
  172. if path.startswith(('http:', 'https:')):
  173. self.warc_headers['WARC-Proxy-Host'] = "http://{0}:{1}".format(host, port)
  174. self.url = path
  175. return
  176. scheme = 'https' if default_port == 443 else 'http'
  177. self.url = scheme + '://' + host
  178. if port != default_port:
  179. self.url += ':' + str(port)
  180. self.url += path
  181. # ============================================================================
# Import-time monkeypatch: code that looks up httplib.HTTPConnection after
# this line gets RecordingHTTPConnection. Classes that already subclassed or
# bound the original before this line runs keep it (the original is saved
# in orig_connection). Connections built while no recorder is active defer
# to the original behaviour.
httplib.HTTPConnection = RecordingHTTPConnection
  183. # ============================================================================
  184. @contextmanager
  185. def capture_http(warc_writer=None, filter_func=None, append=True,
  186. record_ip=True, **kwargs):
  187. out = None
  188. if warc_writer == None:
  189. if 'gzip' not in kwargs:
  190. kwargs['gzip'] = False
  191. warc_writer = BufferWARCWriter(**kwargs)
  192. if isinstance(warc_writer, str):
  193. out = open(warc_writer, 'ab' if append else 'xb')
  194. warc_writer = WARCWriter(out, **kwargs)
  195. try:
  196. recorder = RequestRecorder(warc_writer, filter_func, record_ip=record_ip)
  197. RecordingHTTPConnection.local.recorder = recorder
  198. yield warc_writer
  199. finally:
  200. RecordingHTTPConnection.local.recorder = None
  201. if out:
  202. out.close()