dct.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. #!/usr/bin/env python3
  2. #
  3. # Copyright(C) 2021 wuyaoping
  4. #
# The DCT algorithm is highly robust but has a lower capacity.
  6. import numpy as np
  7. import os.path as osp
  8. import cv2
  9. FLAG = '%'
  10. # Select a part location from the middle frequency
  11. LOC_MAX = (4, 1)
  12. LOC_MIN = (3, 2)
  13. # The difference between MAX and MIN,
  14. # bigger to improve robust but make picture low quality.
  15. ALPHA = 1
  16. # Quantizer table
  17. TABLE = np.array([
  18. [16, 11, 10, 16, 24, 40, 51, 61],
  19. [12, 12, 14, 19, 26, 58, 60, 55],
  20. [14, 13, 16, 24, 40, 57, 69, 56],
  21. [14, 17, 22, 29, 51, 87, 80, 62],
  22. [18, 22, 37, 56, 68, 109, 103, 77],
  23. [24, 35, 55, 64, 81, 104, 113, 92],
  24. [49, 64, 78, 87, 103, 121, 120, 101],
  25. [72, 92, 95, 98, 112, 100, 103, 99]
  26. ])
  27. def insert(path, txt):
  28. img = cv2.imread(path, cv2.IMREAD_ANYCOLOR)
  29. txt = "{}{}{}".format(len(txt), FLAG, txt)
  30. row, col = img.shape[:2]
  31. max_bytes = (row // 8) * (col // 8) // 8
  32. assert max_bytes >= len(
  33. txt), "Message overflow the capacity:{}".format(max_bytes)
  34. img = cv2.cvtColor(img, cv2.COLOR_BGR2YUV)
  35. # Just use the Y plane to store message, you can use all plane
  36. y, u, v = cv2.split(img)
  37. y = y.astype(np.float32)
  38. blocks = []
  39. # Quantize blocks
  40. for r_idx in range(0, 8 * (row // 8), 8):
  41. for c_idx in range(0, 8 * (col // 8), 8):
  42. quantized = cv2.dct(y[r_idx: r_idx+8, c_idx: c_idx+8]) / TABLE
  43. blocks.append(quantized)
  44. for idx in range(len(txt)):
  45. encode(blocks[idx*8: (idx+1)*8], txt[idx])
  46. idx = 0
  47. # Restore Y plane
  48. for r_idx in range(0, 8 * (row // 8), 8):
  49. for c_idx in range(0, 8 * (col // 8), 8):
  50. y[r_idx: r_idx+8, c_idx: c_idx+8] = cv2.idct(blocks[idx] * TABLE)
  51. idx += 1
  52. y = y.astype(np.uint8)
  53. img = cv2.cvtColor(cv2.merge((y, u, v)), cv2.COLOR_YUV2BGR)
  54. filename, _ = osp.splitext(path)
  55. # DCT algorithm can save message even if jpg
  56. filename += '_dct_embeded' + '.jpg'
  57. cv2.imwrite(filename, img)
  58. return filename
  59. # Encode a char into the blocks
  60. def encode(blocks, data):
  61. data = ord(data)
  62. for idx in range(len(blocks)):
  63. bit_val = (data >> idx) & 1
  64. max_val = max(blocks[idx][LOC_MAX], blocks[idx][LOC_MIN])
  65. min_val = min(blocks[idx][LOC_MAX], blocks[idx][LOC_MIN])
  66. if max_val - min_val <= ALPHA:
  67. max_val = min_val + ALPHA + 1e-3
  68. if bit_val == 1:
  69. blocks[idx][LOC_MAX] = max_val
  70. blocks[idx][LOC_MIN] = min_val
  71. else:
  72. blocks[idx][LOC_MAX] = min_val
  73. blocks[idx][LOC_MIN] = max_val
  74. # Decode a char from the blocks
  75. def decode(blocks):
  76. val = 0
  77. for idx in range(len(blocks)):
  78. if blocks[idx][LOC_MAX] > blocks[idx][LOC_MIN]:
  79. val |= 1 << idx
  80. return chr(val)
  81. def extract(path):
  82. img = cv2.imread(path, cv2.IMREAD_ANYCOLOR)
  83. row, col = img.shape[:2]
  84. max_bytes = (row // 8) * (col // 8) // 8
  85. img = cv2.cvtColor(img, cv2.COLOR_BGR2YUV)
  86. y, u, v = cv2.split(img)
  87. y = y.astype(np.float32)
  88. blocks = []
  89. for r_idx in range(0, 8 * (row // 8), 8):
  90. for c_idx in range(0, 8 * (col // 8), 8):
  91. quantized = cv2.dct(y[r_idx: r_idx+8, c_idx: c_idx+8]) / TABLE
  92. blocks.append(quantized)
  93. res = ''
  94. idx = 0
  95. # Extract the length of the message
  96. while idx < max_bytes:
  97. ch = decode(blocks[idx*8: (idx+1)*8])
  98. idx += 1
  99. if ch == FLAG:
  100. break
  101. res += ch
  102. end = int(res) + idx
  103. assert end <= max_bytes, "Input image isn't correct."
  104. res = ''
  105. while idx < end:
  106. res += decode(blocks[idx*8: (idx+1)*8])
  107. idx += 1
  108. return res
  109. if __name__ == '__main__':
  110. data = 'A collection of simple python mini projects to enhance your Python skills.'
  111. res_path = insert('./example.png', data)
  112. res = extract(res_path)
  113. print(res)