import base64
import sys

from warcio.limitreader import LimitReader
from warcio.utils import to_native_str, Digester
from warcio.exceptions import ArchiveLoadFailed


# ============================================================================
class DigestChecker(object):
    """
    Collects the outcome of digest verification.

    ``kind`` selects what happens when a problem is reported:
    ``'raise'`` raises :class:`ArchiveLoadFailed`, ``'log'`` writes the
    message to stderr, and any other value (e.g. ``None``) only records it
    in :attr:`problems`.
    """

    def __init__(self, kind=None):
        self._problem = []
        # None until a result is recorded. DigestVerifyingReader sets it to
        # True on a matching digest unless it is already False; problem()
        # overwrites it with its ``passed`` argument (False by default).
        self._passed = None
        self.kind = kind

    @property
    def passed(self):
        """Verification result so far: None, True or False."""
        return self._passed

    @passed.setter
    def passed(self, value):
        self._passed = value

    @property
    def problems(self):
        """List of problem messages reported so far, in report order."""
        return self._problem

    def problem(self, value, passed=False):
        """
        Record a problem message, then set :attr:`passed` to ``passed``.

        :param value: human-readable description of the problem
        :param passed: value stored in :attr:`passed` afterwards
        :raises ArchiveLoadFailed: when ``kind == 'raise'``; the message is
            recorded first and :attr:`passed` is left unchanged
        """
        self._problem.append(value)
        if self.kind == 'raise':
            raise ArchiveLoadFailed(value)
        if self.kind == 'log':
            sys.stderr.write(value + '\n')
        self._passed = passed


# ============================================================================
class DigestVerifyingReader(LimitReader):
    """
    A reader which verifies the digest of the wrapped reader.

    Every byte read through this reader feeds the block digester. Only bytes
    read after :meth:`begin_payload` feed the payload digester. Both are
    compared against the expected values once ``limit`` reaches zero, and
    results are reported to ``digest_checker``.
    """

    def __init__(self, stream, limit, digest_checker, record_type=None,
                 payload_digest=None, block_digest=None, segment_number=None):
        """
        :param stream: underlying stream, passed to :class:`LimitReader`
        :param limit: byte limit, passed to :class:`LimitReader`
        :param digest_checker: :class:`DigestChecker` receiving the results
        :param record_type: record type; ``'revisit'`` disables both checks
        :param payload_digest: expected payload digest as ``'algo:value'``
        :param block_digest: expected block digest as ``'algo:value'``
        :param segment_number: when not None, payload checking is disabled
        """
        super(DigestVerifyingReader, self).__init__(stream, limit)

        self.digest_checker = digest_checker

        # Neither digest is verified for revisit records.
        if record_type == 'revisit':
            block_digest = None
            payload_digest = None
        # Presumably a segmented record holds only part of the payload, so
        # its payload digest cannot be verified here.
        if segment_number is not None:  #pragma: no cover
            payload_digest = None

        self.payload_digest = payload_digest
        self.block_digest = block_digest

        # The payload digester is created up front but only becomes active
        # (payload_digester) when begin_payload() is called.
        self.payload_digester = None
        self.payload_digester_obj = None
        self.block_digester = None

        if block_digest:
            try:
                algo, _ = _parse_digest(block_digest)
                self.block_digester = Digester(algo)
            except ValueError:
                self.digest_checker.problem('unknown hash algorithm name in block digest')
                self.block_digester = None
        if payload_digest:
            try:
                algo, _ = _parse_digest(self.payload_digest)
                self.payload_digester_obj = Digester(algo)
            except ValueError:
                self.digest_checker.problem('unknown hash algorithm name in payload digest')
                self.payload_digester_obj = None

    def begin_payload(self):
        """
        Start feeding subsequently read bytes to the payload digester.

        If the limit is already exhausted, the payload digest is compared
        immediately.
        """
        self.payload_digester = self.payload_digester_obj
        if self.limit == 0:
            check = _compare_digest_rfc_3548(self.payload_digester, self.payload_digest)
            failure_msg = 'payload digest failed: {}'.format(self.payload_digest)
            if self._report_check(check, failure_msg) is False:
                self.payload_digester = None  # prevent double-fire

    def _update(self, buff):
        """
        Feed ``buff`` to the active digesters. When the limit reaches zero,
        compare the block digest and then the payload digest.

        :returns: ``buff`` unchanged
        """
        super(DigestVerifyingReader, self)._update(buff)

        if self.payload_digester:
            self.payload_digester.update(buff)
        if self.block_digester:
            self.block_digester.update(buff)

        if self.limit == 0:
            check = _compare_digest_rfc_3548(self.block_digester, self.block_digest)
            self._report_check(check, 'block digest failed: {}'.format(self.block_digest))

            check = _compare_digest_rfc_3548(self.payload_digester, self.payload_digest)
            # NOTE(review): this message lacks the ':' used in begin_payload;
            # kept as-is in case callers match on the exact text.
            self._report_check(check, 'payload digest failed {}'.format(self.payload_digest))

        return buff

    def _report_check(self, check, failure_msg):
        """
        Forward one comparison result to the digest checker.

        :param check: result of :func:`_compare_digest_rfc_3548` (True,
            False, or None when there was nothing to compare)
        :param failure_msg: problem message reported when ``check`` is False
        :returns: ``check`` unchanged
        """
        if check is False:
            self.digest_checker.problem(failure_msg)
        elif check is True and self.digest_checker.passed is not False:
            # A later success never overrides an earlier failure.
            self.digest_checker.passed = True
        return check


def _compare_digest_rfc_3548(digester, digest):
    '''
    The WARC standard does not recommend a digest algorithm and appears to
    allow any encoding from RFC3548. The Python base64 module supports
    RFC3548, although the base64 alternate alphabet is not exactly a
    first-class citizen. Hopefully digest algos are named with the same
    names used by OpenSSL.

    :param digester: Digester that has consumed the data; ``str(digester)``
        is parsed as ``'algo:value'`` and its value is treated as base32
    :param digest: expected digest from the record, ``'algo:value'``
    :returns: None if either argument is empty, True if the values match,
        and False otherwise, including when ``digest``'s value cannot be
        decoded
    '''
    if not digester or not digest:
        return None

    _, our_value = _parse_digest(str(digester))
    _, warc_value = _parse_digest(digest)

    try:
        warc_b32 = _to_b32(len(our_value), warc_value)
    except ValueError:
        # Malformed base16/base64 in the record (binascii.Error subclasses
        # ValueError). Report a failed digest instead of letting the decode
        # error escape from read().
        return False

    return our_value == warc_b32


def _to_b32(length, value):
    '''
    Convert value to base 32, given that it's supposed to have the same
    length as the base32 digest we're about to compare it to.

    A value of equal length is assumed to be base32 already. A longer value
    is decoded as base16, and a shorter one as base64 (standard or
    URL/filename-safe alphabet).

    :raises ValueError: (binascii.Error) if the value cannot be decoded
    '''
    if len(value) == length:
        # casefold needed here? -- rfc recommends not allowing
        return value

    if len(value) > length:
        binary = base64.b16decode(value, casefold=True)
    else:
        binary = _b64_wrapper(value)

    return to_native_str(base64.b32encode(binary), encoding='ascii')


# Alternate characters for '+' and '/' in the URL/filename-safe base64 alphabet
base64_url_filename_safe_alt = b'-_'


def _b64_wrapper(value):
    '''
    Decode base64. The URL/filename-safe alphabet is used when the value
    contains '-' or '_'; otherwise the standard alphabet is used.
    '''
    if '-' in value or '_' in value:
        return base64.b64decode(value, altchars=base64_url_filename_safe_alt)
    else:
        return base64.b64decode(value)


def _parse_digest(digest):
    '''
    Split a labelled digest ``'algo:value'`` at its first ':'.

    :returns: ``(algo, value)`` tuple
    :raises ValueError: if ``digest`` contains no ':'
    '''
    algo, sep, value = digest.partition(':')
    if sep == ':':
        return algo, value
    else:
        raise ValueError('could not parse digest algorithm out of '+digest)