lsb.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
#!/usr/bin/env python3
#
# Copyright(C) 2021 wuyaoping
#
# The LSB (least significant bit) algorithm has a large capacity but is fragile.
  6. import cv2
  7. import math
  8. import os.path as osp
  9. import numpy as np
  10. # Insert data in the low bit.
  11. # Lower make picture less loss but lower capacity.
  12. BITS = 2
  13. HIGH_BITS = 256 - (1 << BITS)
  14. LOW_BITS = (1 << BITS) - 1
  15. BYTES_PER_BYTE = math.ceil(8 / BITS)
  16. FLAG = '%'
  17. def insert(path, txt):
  18. img = cv2.imread(path, cv2.IMREAD_ANYCOLOR)
  19. # Save origin shape to restore image
  20. ori_shape = img.shape
  21. max_bytes = ori_shape[0] * ori_shape[1] // BYTES_PER_BYTE
  22. # Encode message with length
  23. txt = '{}{}{}'.format(len(txt), FLAG, txt)
  24. assert max_bytes >= len(
  25. txt), "Message overflow the capacity:{}".format(max_bytes)
  26. data = np.reshape(img, -1)
  27. for (idx, val) in enumerate(txt):
  28. encode(data[idx*BYTES_PER_BYTE: (idx+1) * BYTES_PER_BYTE], val)
  29. img = np.reshape(data, ori_shape)
  30. filename, _ = osp.splitext(path)
  31. # png is lossless encode that can restore message correctly
  32. filename += '_lsb_embeded' + ".png"
  33. cv2.imwrite(filename, img)
  34. return filename
  35. def extract(path):
  36. img = cv2.imread(path, cv2.IMREAD_ANYCOLOR)
  37. data = np.reshape(img, -1)
  38. total = data.shape[0]
  39. res = ''
  40. idx = 0
  41. # Decode message length
  42. while idx < total // BYTES_PER_BYTE:
  43. ch = decode(data[idx*BYTES_PER_BYTE: (idx+1)*BYTES_PER_BYTE])
  44. idx += 1
  45. if ch == FLAG:
  46. break
  47. res += ch
  48. end = int(res) + idx
  49. assert end <= total // BYTES_PER_BYTE, "Input image isn't correct."
  50. res = ''
  51. while idx < end:
  52. res += decode(data[idx*BYTES_PER_BYTE: (idx+1)*BYTES_PER_BYTE])
  53. idx += 1
  54. return res
  55. def encode(block, data):
  56. data = ord(data)
  57. for idx in range(len(block)):
  58. block[idx] &= HIGH_BITS
  59. block[idx] |= (data >> (BITS * idx)) & LOW_BITS
  60. def decode(block):
  61. val = 0
  62. for idx in range(len(block)):
  63. val |= (block[idx] & LOW_BITS) << (idx * BITS)
  64. return chr(val)
  65. if __name__ == '__main__':
  66. data = 'A collection of simple python mini projects to enhance your Python skills.'
  67. input_path = "./example.png"
  68. res_path = insert(input_path, data)
  69. res = extract(res_path)
  70. print(res)