visualize_attention_map.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # encoding: utf-8
  2. """
  3. @author: rentianhe
  4. @contact: 596106517@qq.com
  5. """
  6. import numpy as np
  7. import cv2
  8. from PIL import Image
  9. import matplotlib
  10. matplotlib.use('Agg')
  11. import matplotlib.pyplot as plt
  12. import os
  13. def visualize_grid_attention(img_path, save_path, attention_mask, ratio=0.5, save_image=True, save_original_image=True, quality=100):
  14. """
  15. img_path: where to load the image
  16. save_path: where to save the image
  17. attention_mask: the 2-D attention mask on your image, e.g: np.array (h, w) or (w, h)
  18. ratio: scaling factor to scale the output h and w
  19. quality: save image quality
  20. """
  21. print("load image from: " + img_path)
  22. img = Image.open(img_path)
  23. img_h, img_w = img.size[0], img.size[1]
  24. plt.subplots(nrows=1, ncols=1, figsize=(0.02 * img_h, 0.02 * img_w))
  25. # scale the image
  26. img_h, img_w = int(img.size[0] * ratio), int(img.size[1] * ratio)
  27. img = img.resize((img_h, img_w))
  28. plt.imshow(img, alpha=1)
  29. plt.axis('off')
  30. # normalize the attention map
  31. mask = cv2.resize(attention_mask, (img_h, img_w)) # you may change the (img_w, img_h) order to adjust the attention mask
  32. normed_mask = mask / mask.max()
  33. normed_mask = cv2.resize(normed_mask, img.size)[..., np.newaxis]
  34. # put the attention map on the original image
  35. result = (img * normed_mask).astype("uint8")
  36. plt.imshow(result, alpha=1)
  37. # save image
  38. if save_image:
  39. # build save path
  40. if not os.path.exists(save_path):
  41. os.mkdir(save_path)
  42. assert save_image is not None, "you need to set where to store the picture"
  43. img_name = img_path.split('/')[-1].split('.')[0] + "_with_attention.jpg"
  44. img_with_attention_save_path = os.path.join(save_path, img_name)
  45. # pre-process before saving
  46. print("save image to: " + save_path)
  47. plt.axis('off')
  48. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  49. plt.margins(0, 0)
  50. plt.savefig(img_with_attention_save_path, dpi=quality)
  51. # save original image
  52. if save_original_image:
  53. # build save path
  54. if not os.path.exists(save_path):
  55. os.mkdir(save_path)
  56. print("save original image at the same time")
  57. img_name = img_path.split('/')[-1].split('.')[0] + "_original.jpg"
  58. original_image_save_path = os.path.join(save_path, img_name)
  59. img.save(original_image_save_path, quality=quality)