# plot_2d_1.py (2.8 KB)

  1. def plot_2d_contour(surf_file, surf_name='train_loss', vmin=0.1, vmax=10, vlevel=0.5, show=False):
  2. """Plot 2D contour map and 3D surface."""
  3. f = h5py.File(surf_file, 'r')
  4. x = np.array(f['xcoordinates'][:])
  5. y = np.array(f['ycoordinates'][:])
  6. X, Y = np.meshgrid(x, y)
  7. if surf_name in f.keys():
  8. Z = np.array(f[surf_name][:])
  9. elif surf_name == 'train_err' or surf_name == 'test_err' :
  10. Z = 100 - np.array(f[surf_name][:])
  11. else:
  12. print ('%s is not found in %s' % (surf_name, surf_file))
  13. print('------------------------------------------------------------------')
  14. print('plot_2d_contour')
  15. print('------------------------------------------------------------------')
  16. print("loading surface file: " + surf_file)
  17. print('len(xcoordinates): %d len(ycoordinates): %d' % (len(x), len(y)))
  18. print('max(%s) = %f \t min(%s) = %f' % (surf_name, np.max(Z), surf_name, np.min(Z)))
  19. print(Z)
  20. if (len(x) <= 1 or len(y) <= 1):
  21. print('The length of coordinates is not enough for plotting contours')
  22. return
  23. # --------------------------------------------------------------------
  24. # Plot 2D contours
  25. # --------------------------------------------------------------------
  26. fig = plt.figure()
  27. CS = plt.contour(X, Y, Z, cmap='summer', levels=np.arange(vmin, vmax, vlevel))
  28. plt.clabel(CS, inline=1, fontsize=8)
  29. fig.savefig(surf_file + '_' + surf_name + '_2dcontour' + '.pdf', dpi=300,
  30. bbox_inches='tight', format='pdf')
  31. fig = plt.figure()
  32. print(surf_file + '_' + surf_name + '_2dcontourf' + '.pdf')
  33. CS = plt.contourf(X, Y, Z, cmap='summer', levels=np.arange(vmin, vmax, vlevel))
  34. fig.savefig(surf_file + '_' + surf_name + '_2dcontourf' + '.pdf', dpi=300,
  35. bbox_inches='tight', format='pdf')
  36. # --------------------------------------------------------------------
  37. # Plot 2D heatmaps
  38. # --------------------------------------------------------------------
  39. fig = plt.figure()
  40. sns_plot = sns.heatmap(Z, cmap='viridis', cbar=True, vmin=vmin, vmax=vmax,
  41. xticklabels=False, yticklabels=False)
  42. sns_plot.invert_yaxis()
  43. sns_plot.get_figure().savefig(surf_file + '_' + surf_name + '_2dheat.pdf',
  44. dpi=300, bbox_inches='tight', format='pdf')
  45. # --------------------------------------------------------------------
  46. # Plot 3D surface
  47. # --------------------------------------------------------------------
  48. fig = plt.figure()
  49. ax = Axes3D(fig)
  50. surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, linewidth=0, antialiased=False)
  51. fig.colorbar(surf, shrink=0.5, aspect=5)
  52. fig.savefig(surf_file + '_' + surf_name + '_3dsurface.pdf', dpi=300,
  53. bbox_inches='tight', format='pdf')
  54. f.close()
  55. if show: plt.show()