cse.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. import base64
  2. import json
  3. import logging
  4. import os
  5. from io import BytesIO
  6. from cryptography.hazmat.primitives.ciphers import Cipher
  7. from cryptography.hazmat.primitives.ciphers.aead import AESGCM
  8. from cryptography.hazmat.primitives.ciphers.algorithms import AES
  9. from cryptography.hazmat.primitives.ciphers.modes import CBC
  10. from cryptography.hazmat.primitives.padding import PKCS7
  11. logger = logging.getLogger(__name__)
  12. AES_BLOCK_SIZE = 128
  13. ALG_CBC = "AES/CBC/PKCS5Padding"
  14. ALG_GCM = "AES/GCM/NoPadding"
  15. HEADER_ALG = "x-amz-cek-alg"
  16. HEADER_KEY = "x-amz-key-v2"
  17. HEADER_IV = "x-amz-iv"
  18. HEADER_MATDESC = "x-amz-matdesc"
  19. HEADER_TAG_LEN = "x-amz-tag-len"
  20. HEADER_UE_CLENGHT = "x-amz-unencrypted-content-length"
  21. HEADER_WRAP_ALG = "x-amz-wrap-alg"
  22. def is_kms_cse_encrypted(s3_metadata):
  23. if HEADER_KEY in s3_metadata:
  24. if s3_metadata.get(HEADER_WRAP_ALG, None) != "kms":
  25. raise ValueError("Unsupported Hash strategy")
  26. if s3_metadata.get(HEADER_ALG, None) not in [ALG_CBC, ALG_GCM]:
  27. raise ValueError("Unsupported Hash algorithm")
  28. return True
  29. elif "x-amz-key" in s3_metadata:
  30. raise ValueError("Unsupported Amazon S3 Hash Client Version")
  31. return False
  32. def get_encryption_aes_key(key, kms_client):
  33. encryption_context = {"kms_cmk_id": key}
  34. response = kms_client.generate_data_key(
  35. KeyId=key, EncryptionContext=encryption_context, KeySpec="AES_256"
  36. )
  37. return (
  38. response["Plaintext"],
  39. encryption_context,
  40. base64.b64encode(response["CiphertextBlob"]).decode(),
  41. )
  42. def get_decryption_aes_key(key, material_description, kms_client):
  43. return kms_client.decrypt(
  44. CiphertextBlob=key, EncryptionContext=material_description
  45. )["Plaintext"]
  46. def encrypt(buf, s3_metadata, kms_client):
  47. """
  48. Method to encrypt an S3 object with KMS based Client-side encryption (CSE).
  49. The original object's metadata (previously used to decrypt the content) is
  50. used to infer some parameters such as the algorithm originally used to encrypt
  51. the previous version (which is left unchanged) and to store the new envelope,
  52. including the initialization vector (IV).
  53. """
  54. logger.info("Encrypting Object with CSE-KMS")
  55. content = buf.read()
  56. alg = s3_metadata.get(HEADER_ALG, None)
  57. matdesc = json.loads(s3_metadata[HEADER_MATDESC])
  58. aes_key, matdesc_metadata, key_metadata = get_encryption_aes_key(
  59. matdesc["kms_cmk_id"], kms_client
  60. )
  61. s3_metadata[HEADER_UE_CLENGHT] = str(len(content))
  62. s3_metadata[HEADER_WRAP_ALG] = "kms"
  63. s3_metadata[HEADER_KEY] = key_metadata
  64. s3_metadata[HEADER_ALG] = alg
  65. if alg == ALG_GCM:
  66. s3_metadata[HEADER_TAG_LEN] = str(AES_BLOCK_SIZE)
  67. result, iv = encrypt_gcm(aes_key, content)
  68. else:
  69. result, iv = encrypt_cbc(aes_key, content)
  70. s3_metadata[HEADER_IV] = base64.b64encode(iv).decode()
  71. return BytesIO(result), s3_metadata
  72. def decrypt(file_input, s3_metadata, kms_client):
  73. """
  74. Method to decrypt an S3 object with KMS based Client-side encryption (CSE).
  75. The object's metadata is used to fetch the encryption envelope such as
  76. the KMS key ID and the algorithm.
  77. """
  78. logger.info("Decrypting Object with CSE-KMS")
  79. alg = s3_metadata.get(HEADER_ALG, None)
  80. iv = base64.b64decode(s3_metadata[HEADER_IV])
  81. material_description = json.loads(s3_metadata[HEADER_MATDESC])
  82. key = s3_metadata[HEADER_KEY]
  83. decryption_key = base64.b64decode(key)
  84. aes_key = get_decryption_aes_key(decryption_key, material_description, kms_client)
  85. content = file_input.read()
  86. decrypted = (
  87. decrypt_gcm(content, aes_key, iv)
  88. if alg == ALG_GCM
  89. else decrypt_cbc(content, aes_key, iv)
  90. )
  91. return BytesIO(decrypted)
  92. # AES/CBC/PKCS5Padding
  93. def encrypt_cbc(aes_key, content):
  94. iv = os.urandom(16)
  95. padder = PKCS7(AES.block_size).padder()
  96. padded_result = padder.update(content) + padder.finalize()
  97. aescbc = Cipher(AES(aes_key), CBC(iv)).encryptor()
  98. result = aescbc.update(padded_result) + aescbc.finalize()
  99. return result, iv
  100. def decrypt_cbc(content, aes_key, iv):
  101. aescbc = Cipher(AES(aes_key), CBC(iv)).decryptor()
  102. padded_result = aescbc.update(content) + aescbc.finalize()
  103. unpadder = PKCS7(AES.block_size).unpadder()
  104. return unpadder.update(padded_result) + unpadder.finalize()
  105. # AES/GCM/NoPadding
  106. def encrypt_gcm(aes_key, content):
  107. iv = os.urandom(12)
  108. aesgcm = AESGCM(aes_key)
  109. result = aesgcm.encrypt(iv, content, None)
  110. return result, iv
  111. def decrypt_gcm(content, aes_key, iv):
  112. aesgcm = AESGCM(aes_key)
  113. return aesgcm.decrypt(iv, content, None)