plot_2d_4.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. def plot_2d_eig_ratio(surf_file, val_1='min_eig', val_2='max_eig', show=False):
  2. """ Plot the heatmap of eigenvalue ratios, i.e., |min_eig/max_eig| of hessian """
  3. print('------------------------------------------------------------------')
  4. print('plot_2d_eig_ratio')
  5. print('------------------------------------------------------------------')
  6. print("loading surface file: " + surf_file)
  7. f = h5py.File(surf_file,'r')
  8. x = np.array(f['xcoordinates'][:])
  9. y = np.array(f['ycoordinates'][:])
  10. X, Y = np.meshgrid(x, y)
  11. Z1 = np.array(f[val_1][:])
  12. Z2 = np.array(f[val_2][:])
  13. # Plot 2D heatmaps with color bar using seaborn
  14. abs_ratio = np.absolute(np.divide(Z1, Z2))
  15. print(abs_ratio)
  16. fig = plt.figure()
  17. sns_plot = sns.heatmap(abs_ratio, cmap='viridis', vmin=0, vmax=.5, cbar=True,
  18. xticklabels=False, yticklabels=False)
  19. sns_plot.invert_yaxis()
  20. sns_plot.get_figure().savefig(surf_file + '_' + val_1 + '_' + val_2 + '_abs_ratio_heat_sns.pdf',
  21. dpi=300, bbox_inches='tight', format='pdf')
  22. # Plot 2D heatmaps with color bar using seaborn
  23. ratio = np.divide(Z1, Z2)
  24. print(ratio)
  25. fig = plt.figure()
  26. sns_plot = sns.heatmap(ratio, cmap='viridis', cbar=True, xticklabels=False, yticklabels=False)
  27. sns_plot.invert_yaxis()
  28. sns_plot.get_figure().savefig(surf_file + '_' + val_1 + '_' + val_2 + '_ratio_heat_sns.pdf',
  29. dpi=300, bbox_inches='tight', format='pdf')
  30. f.close()
  31. if show: plt.show()
  32. if __name__ == '__main__':
  33. parser = argparse.ArgumentParser(description='Plot 2D loss surface')
  34. parser.add_argument('--surf_file', '-f', default='', help='The h5 file that contains surface values')
  35. parser.add_argument('--dir_file', default='', help='The h5 file that contains directions')
  36. parser.add_argument('--proj_file', default='', help='The h5 file that contains the projected trajectories')
  37. parser.add_argument('--surf_name', default='train_loss', help='The type of surface to plot')
  38. parser.add_argument('--vmax', default=10, type=float, help='Maximum value to map')
  39. parser.add_argument('--vmin', default=0.1, type=float, help='Miminum value to map')
  40. parser.add_argument('--vlevel', default=0.5, type=float, help='plot contours every vlevel')
  41. parser.add_argument('--zlim', default=10, type=float, help='Maximum loss value to show')
  42. parser.add_argument('--show', action='store_true', default=False, help='show plots')
  43. args = parser.parse_args()
  44. if exists(args.surf_file) and exists(args.proj_file) and exists(args.dir_file):
  45. plot_contour_trajectory(args.surf_file, args.dir_file, args.proj_file,
  46. args.surf_name, args.vmin, args.vmax, args.vlevel, args.show)
  47. elif exists(args.proj_file) and exists(args.dir_file):
  48. plot_trajectory(args.proj_file, args.dir_file, args.show)
  49. elif exists(args.surf_file):
  50. plot_2d_contour(args.surf_file, args.surf_name, args.vmin, args.vmax, args.vlevel, args.show)