def _save_ratio_heatmap(values, out_path, **heatmap_kwargs):
    """Draw `values` as a seaborn heatmap on a fresh figure and save it as PDF.

    Args:
        values: array passed straight to `sns.heatmap`.
        out_path: destination path of the PDF file.
        **heatmap_kwargs: extra keyword arguments forwarded to `sns.heatmap`
            (e.g. `vmin` / `vmax` to clip the color scale).
    """
    # A new figure is required because sns.heatmap draws on the current axes.
    fig = plt.figure()
    ax = sns.heatmap(values, cmap='viridis', cbar=True,
                     xticklabels=False, yticklabels=False, **heatmap_kwargs)
    ax.invert_yaxis()
    fig.savefig(out_path, dpi=300, bbox_inches='tight', format='pdf')


def plot_2d_eig_ratio(surf_file, val_1='min_eig', val_2='max_eig', show=False):
    """Plot heatmaps of the ratio of two datasets stored in a surface file.

    With the default dataset names this is the eigenvalue ratio of the
    hessian, i.e. min_eig/max_eig. Two PDF files are written next to
    `surf_file`:

    * `<surf_file>_<val_1>_<val_2>_abs_ratio_heat_sns.pdf` -- the absolute
      ratio, with the color scale fixed to [0, 0.5];
    * `<surf_file>_<val_1>_<val_2>_ratio_heat_sns.pdf` -- the signed ratio,
      with the color scale chosen by seaborn.

    Args:
        surf_file: path of the h5 file to read.
        val_1: name of the dataset used as numerator.
        val_2: name of the dataset used as denominator.
        show: when true, call `plt.show()` after both figures are saved.
    """
    print('------------------------------------------------------------------')
    print('plot_2d_eig_ratio')
    print('------------------------------------------------------------------')
    print("loading surface file: " + surf_file)

    # Copy both datasets into memory inside a context manager so the file is
    # closed even if a dataset is missing or plotting fails later on.
    with h5py.File(surf_file, 'r') as f:
        Z1 = np.array(f[val_1][:])
        Z2 = np.array(f[val_2][:])

    # NOTE: entries where Z2 is zero yield inf/nan (numpy emits a
    # RuntimeWarning); they are passed to the heatmap unchanged.
    ratio = np.divide(Z1, Z2)
    abs_ratio = np.absolute(ratio)
    out_prefix = surf_file + '_' + val_1 + '_' + val_2

    # Heatmap of |val_1 / val_2| with a fixed color range.
    print(abs_ratio)
    _save_ratio_heatmap(abs_ratio, out_prefix + '_abs_ratio_heat_sns.pdf',
                        vmin=0, vmax=.5)

    # Heatmap of the signed ratio with an automatic color range.
    print(ratio)
    _save_ratio_heatmap(ratio, out_prefix + '_ratio_heat_sns.pdf')

    if show:
        plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Plot 2D loss surface')
    parser.add_argument('--surf_file', '-f', default='',
                        help='The h5 file that contains surface values')
    parser.add_argument('--dir_file', default='',
                        help='The h5 file that contains directions')
    parser.add_argument('--proj_file', default='',
                        help='The h5 file that contains the projected trajectories')
    parser.add_argument('--surf_name', default='train_loss',
                        help='The type of surface to plot')
    parser.add_argument('--vmax', default=10, type=float,
                        help='Maximum value to map')
    parser.add_argument('--vmin', default=0.1, type=float,
                        help='Minimum value to map')
    parser.add_argument('--vlevel', default=0.5, type=float,
                        help='plot contours every vlevel')
    # NOTE(review): --zlim is parsed but not passed to any plot call below.
    parser.add_argument('--zlim', default=10, type=float,
                        help='Maximum loss value to show')
    parser.add_argument('--show', action='store_true', default=False,
                        help='show plots')
    args = parser.parse_args()

    # Pick the plot from the files that actually exist on disk; if none of
    # the branches match, nothing is plotted.
    if exists(args.surf_file) and exists(args.proj_file) and exists(args.dir_file):
        plot_contour_trajectory(args.surf_file, args.dir_file, args.proj_file,
                                args.surf_name, args.vmin, args.vmax,
                                args.vlevel, args.show)
    elif exists(args.proj_file) and exists(args.dir_file):
        plot_trajectory(args.proj_file, args.dir_file, args.show)
    elif exists(args.surf_file):
        plot_2d_contour(args.surf_file, args.surf_name, args.vmin, args.vmax,
                        args.vlevel, args.show)