123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
def plot_2d_eig_ratio(surf_file, val_1='min_eig', val_2='max_eig', show=False):
    """Plot heatmaps of the ratio between two surfaces stored in an HDF5 file.

    By default the surfaces are the min/max Hessian eigenvalues, so the plots
    show |min_eig / max_eig| and min_eig / max_eig. Two PDFs are written next
    to ``surf_file``:

    * ``<surf_file>_<val_1>_<val_2>_abs_ratio_heat_sns.pdf`` -- |val_1 / val_2|
      with the color scale fixed to [0, 0.5];
    * ``<surf_file>_<val_1>_<val_2>_ratio_heat_sns.pdf`` -- the signed ratio
      val_1 / val_2 with seaborn's automatic color scale.

    Args:
        surf_file: path to an HDF5 file containing datasets ``val_1`` and
            ``val_2`` (presumably of equal shape -- they are divided
            element-wise).
        val_1: name of the numerator dataset.
        val_2: name of the denominator dataset.
        show: if True, call ``plt.show()`` after both figures are saved.
    """
    print('------------------------------------------------------------------')
    print('plot_2d_eig_ratio')
    print('------------------------------------------------------------------')
    print("loading surface file: " + surf_file)
    # Load both surfaces and release the file right away; the context manager
    # also closes the handle if a dataset is missing or reading fails.
    with h5py.File(surf_file, 'r') as f:
        Z1 = np.array(f[val_1][:])
        Z2 = np.array(f[val_2][:])

    # Compute the element-wise ratio once; its absolute value is identical to
    # np.absolute(np.divide(Z1, Z2)).
    ratio = np.divide(Z1, Z2)
    abs_ratio = np.absolute(ratio)
    out_prefix = surf_file + '_' + val_1 + '_' + val_2

    # Heatmap of |val_1 / val_2| with a fixed [0, 0.5] color range.
    print(abs_ratio)
    plt.figure()
    sns_plot = sns.heatmap(abs_ratio, cmap='viridis', vmin=0, vmax=.5, cbar=True,
                           xticklabels=False, yticklabels=False)
    sns_plot.invert_yaxis()  # put row 0 at the bottom
    sns_plot.get_figure().savefig(out_prefix + '_abs_ratio_heat_sns.pdf',
                                  dpi=300, bbox_inches='tight', format='pdf')

    # Heatmap of the signed ratio val_1 / val_2 (automatic color range).
    print(ratio)
    plt.figure()
    sns_plot = sns.heatmap(ratio, cmap='viridis', cbar=True,
                           xticklabels=False, yticklabels=False)
    sns_plot.invert_yaxis()
    sns_plot.get_figure().savefig(out_prefix + '_ratio_heat_sns.pdf',
                                  dpi=300, bbox_inches='tight', format='pdf')

    if show:
        plt.show()
if __name__ == '__main__':
    # Command-line entry point: choose a plot type based on which input files exist.
    parser = argparse.ArgumentParser(description='Plot 2D loss surface')
    parser.add_argument('--surf_file', '-f', default='', help='The h5 file that contains surface values')
    parser.add_argument('--dir_file', default='', help='The h5 file that contains directions')
    parser.add_argument('--proj_file', default='', help='The h5 file that contains the projected trajectories')
    parser.add_argument('--surf_name', default='train_loss', help='The type of surface to plot')
    parser.add_argument('--vmax', default=10, type=float, help='Maximum value to map')
    parser.add_argument('--vmin', default=0.1, type=float, help='Minimum value to map')
    parser.add_argument('--vlevel', default=0.5, type=float, help='plot contours every vlevel')
    # NOTE(review): --zlim is parsed but not passed to any of the calls below.
    parser.add_argument('--zlim', default=10, type=float, help='Maximum loss value to show')
    parser.add_argument('--show', action='store_true', default=False, help='show plots')
    args = parser.parse_args()

    # Dispatch from the most specific input combination to the least:
    # surface + projection + directions -> contour with trajectory overlay,
    # projection + directions -> trajectory only, surface alone -> 2D contour.
    if exists(args.surf_file) and exists(args.proj_file) and exists(args.dir_file):
        plot_contour_trajectory(args.surf_file, args.dir_file, args.proj_file,
                                args.surf_name, args.vmin, args.vmax, args.vlevel, args.show)
    elif exists(args.proj_file) and exists(args.dir_file):
        plot_trajectory(args.proj_file, args.dir_file, args.show)
    elif exists(args.surf_file):
        plot_2d_contour(args.surf_file, args.surf_name, args.vmin, args.vmax, args.vlevel, args.show)
    else:
        # Previously this case silently exited without plotting anything.
        parser.error('no existing input file given: pass --surf_file, '
                     'or both --proj_file and --dir_file')
|