123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
def plot_2d_contour(surf_file, surf_name='train_loss', vmin=0.1, vmax=10, vlevel=0.5, show=False):
    """Plot 2D contour maps, a 2D heatmap and a 3D surface of a stored surface.

    Reads ``xcoordinates``, ``ycoordinates`` and the ``surf_name`` dataset from
    the HDF5 file ``surf_file`` and writes four PDFs next to it, each named
    ``<surf_file>_<surf_name>_<suffix>.pdf`` with suffix ``2dcontour``,
    ``2dcontourf``, ``2dheat`` and ``3dsurface``.

    Args:
        surf_file: path to the HDF5 surface file.
        surf_name: dataset to plot. ``'train_err'`` / ``'test_err'`` that are
            not stored directly are derived as ``100 - <split>_acc``.
        vmin: first contour level; also the lower bound of the heatmap colours.
        vmax: upper bound (exclusive) of the contour levels; also the upper
            bound of the heatmap colours.
        vlevel: spacing between contour levels (``np.arange(vmin, vmax, vlevel)``).
        show: if True, display the figures with ``plt.show()`` after saving;
            otherwise close them so repeated calls do not accumulate figures.

    Returns:
        None. Returns early, after printing a message and without plotting, if
        ``surf_name`` cannot be found or either coordinate axis has fewer than
        two points.
    """
    # Copy everything into memory inside a context manager so the HDF5 file is
    # closed on every path, including the early returns below.
    with h5py.File(surf_file, 'r') as f:
        x = np.array(f['xcoordinates'][:])
        y = np.array(f['ycoordinates'][:])
        acc_name = {'train_err': 'train_acc', 'test_err': 'test_acc'}.get(surf_name)
        if surf_name in f:
            Z = np.array(f[surf_name][:])
        elif acc_name is not None and acc_name in f:
            # The error surface is not stored, so derive it from accuracy.
            # NOTE(review): assumes accuracy is stored in percent under
            # '<split>_acc' -- confirm against the code that writes the file.
            Z = 100 - np.array(f[acc_name][:])
        else:
            print('%s is not found in %s' % (surf_name, surf_file))
            # Without Z nothing can be plotted (previously this fell through
            # to a NameError).
            return

    X, Y = np.meshgrid(x, y)

    print('------------------------------------------------------------------')
    print('plot_2d_contour')
    print('------------------------------------------------------------------')
    print("loading surface file: " + surf_file)
    print('len(xcoordinates): %d len(ycoordinates): %d' % (len(x), len(y)))
    print('max(%s) = %f \t min(%s) = %f' % (surf_name, np.max(Z), surf_name, np.min(Z)))
    print(Z)
    if len(x) <= 1 or len(y) <= 1:
        print('The length of coordinates is not enough for plotting contours')
        return

    levels = np.arange(vmin, vmax, vlevel)
    figures = []

    def save(fig, suffix):
        # Every output shares the '<surf_file>_<surf_name>_<suffix>.pdf' naming.
        fig.savefig(surf_file + '_' + surf_name + '_' + suffix + '.pdf', dpi=300,
                    bbox_inches='tight', format='pdf')
        figures.append(fig)

    # --------------------------------------------------------------------
    # Plot 2D contours (line contours with labels, then filled contours)
    # --------------------------------------------------------------------
    fig = plt.figure()
    cs = plt.contour(X, Y, Z, cmap='summer', levels=levels)
    plt.clabel(cs, inline=True, fontsize=8)
    save(fig, '2dcontour')

    fig = plt.figure()
    print(surf_file + '_' + surf_name + '_2dcontourf' + '.pdf')
    plt.contourf(X, Y, Z, cmap='summer', levels=levels)
    save(fig, '2dcontourf')

    # --------------------------------------------------------------------
    # Plot 2D heatmap
    # --------------------------------------------------------------------
    plt.figure()
    sns_plot = sns.heatmap(Z, cmap='viridis', cbar=True, vmin=vmin, vmax=vmax,
                           xticklabels=False, yticklabels=False)
    # Put row 0 at the bottom so the heatmap is oriented like the contour plots.
    sns_plot.invert_yaxis()
    save(sns_plot.get_figure(), '2dheat')

    # --------------------------------------------------------------------
    # Plot 3D surface
    # --------------------------------------------------------------------
    fig = plt.figure()
    # Axes3D(fig) no longer attaches itself to the figure in Matplotlib >= 3.6,
    # which produced an empty PDF; add_subplot works across versions.
    ax = fig.add_subplot(111, projection='3d')
    surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, linewidth=0, antialiased=False)
    fig.colorbar(surf, shrink=0.5, aspect=5)
    save(fig, '3dsurface')

    if show:
        plt.show()
    else:
        # Release the figures; pyplot otherwise keeps them alive indefinitely.
        for fig in figures:
            plt.close(fig)
|