def biplot(cscore=None, loadings=None, labels=None, var1=None, var2=None, var3=None, axlabelfontsize=9,
           axlabelfontname="Arial", figtype='png', r=300, show=False, markerdot="o", dotsize=6, valphadot=1,
           colordot='#eba487', arrowcolor='#87ceeb', valphaarrow=1, arrowlinestyle='-', arrowlinewidth=0.5,
           centerlines=True, colorlist=None, legendpos='best', datapoints=True, dim=(6, 4), theme=None):
    """Draw a 2D or 3D biplot: component scores as points, loadings as arrows.

    A 2D plot is produced when ``var3`` is None, otherwise a 3D plot. Each
    plotted score column is divided by its own range (max - min) before
    drawing; loadings are drawn unscaled from the origin.

    Parameters
    ----------
    cscore : array indexable as ``cscore[:, k]``
        Component scores; columns 0-1 are used for 2D, 0-2 for 3D.
    loadings : indexable as ``loadings[k][i]``
        Loading of variable ``i`` on component ``k``; ``len(loadings[0])``
        arrows are drawn.
    labels : indexable as ``labels[i]``
        Text placed at the tip of arrow ``i``.
    var1, var2, var3 :
        Values inserted into the axis labels ``"PC1 ({}%)"``, ``"PC2 ({}%)"``
        and ``"PC3 ({}%)"``. ``var3`` also selects the 3D plot.
    axlabelfontsize, axlabelfontname :
        Font size and font name for the axis labels.
    figtype, r, show, theme :
        Forwarded to ``general.get_figure``; ``theme == 'dark'`` additionally
        calls ``general.dark_bg()`` first.
        NOTE(review): presumably output format, resolution and
        show-versus-save -- confirm against the ``general`` helpers.
    markerdot, dotsize, valphadot :
        Marker, size and alpha of the score points.
    colordot :
        Without ``colorlist``: the colour of every point. With ``colorlist``:
        a non-empty list/tuple is used as a ``ListedColormap`` for the
        classes; any other value is ignored and the default colormap is used.
    arrowcolor, valphaarrow, arrowlinestyle, arrowlinewidth :
        Colour, alpha, line style and line width of the loading arrows.
    centerlines : bool
        Draw dashed grey lines through x=0 and y=0 (2D plot only).
    colorlist : iterable of hashable class labels, optional
        One class per row of ``cscore``; points are coloured by class and a
        legend is added. Classes are numbered in order of first appearance,
        so the class-to-colour mapping is the same on every run.
    legendpos :
        ``loc`` argument of the legend.
    datapoints : bool
        Whether the score points are drawn at all.
    dim : tuple
        Figure size passed as ``figsize``.

    Raises
    ------
    AssertionError
        If any of ``cscore``, ``loadings``, ``labels``, ``var1`` or ``var2``
        is None.
    """
    if theme == 'dark':
        general.dark_bg()

    # Raised explicitly instead of via an ``assert`` statement so the check is
    # not stripped under ``python -O``; the exception type is unchanged.
    if any(arg is None for arg in (cscore, loadings, labels, var1, var2)):
        raise AssertionError("cscore or loadings or labels or var1 or var2 are missing")

    def _unit_range(column):
        # Scale a score column by 1/range; a constant column (range 0) is
        # left unscaled so no inf/NaN coordinates or axis limits are produced.
        span = column.max() - column.min()
        return column * (1.0 / span) if span != 0 else column * 1.0

    three_d = var3 is not None
    n_axes = 3 if three_d else 2
    coords = [_unit_range(cscore[:, k]) for k in range(n_axes)]

    if three_d:
        fig = plt.figure(figsize=dim)
        ax = fig.add_subplot(111, projection='3d')
    else:
        _, ax = plt.subplots(figsize=dim)

    if datapoints:
        if colorlist is not None:
            # First-appearance order is deterministic, unlike set iteration
            # order, which varies between runs for string labels.
            classes = list(dict.fromkeys(colorlist))
            class_index = {cls: idx for idx, cls in enumerate(classes)}
            codes = [class_index[cls] for cls in colorlist]

            scatter_kwargs = dict(c=codes, s=dotsize, alpha=valphadot, marker=markerdot)
            if colordot and isinstance(colordot, (tuple, list)):
                scatter_kwargs['cmap'] = ListedColormap(colordot)
            # Any other colordot value falls through to the default colormap,
            # so the points are always drawn.
            points = ax.scatter(*coords, **scatter_kwargs)

            # num=None yields one handle per class; the default "auto"
            # subsamples above 9 classes and would misalign the labels.
            ax.legend(handles=points.legend_elements(num=None)[0], labels=classes, loc=legendpos)
        else:
            ax.scatter(*coords, color=colordot, s=dotsize, alpha=valphadot, marker=markerdot)

    if centerlines and not three_d:
        ax.axhline(y=0, linestyle='--', color='#7d7d7d', linewidth=1)
        ax.axvline(x=0, linestyle='--', color='#7d7d7d', linewidth=1)

    # len(loadings[0]) is the number of original variables; this matters when
    # there are more variables than observations.
    for i in range(len(loadings[0])):
        tip = [loadings[k][i] for k in range(n_axes)]
        if three_d:
            ax.quiver(0, 0, 0, *tip, color=arrowcolor, alpha=valphaarrow, ls=arrowlinestyle,
                      lw=arrowlinewidth)
        else:
            ax.arrow(0, 0, *tip, color=arrowcolor, alpha=valphaarrow, ls=arrowlinestyle,
                     lw=arrowlinewidth)
        ax.text(*tip, labels[i])

    # Axis limits enclose both the scaled scores and the loadings, plus a
    # 0.2 margin on each side.
    limit_setters = [ax.set_xlim, ax.set_ylim]
    if three_d:
        limit_setters.append(ax.set_zlim)
    for k, set_limit in enumerate(limit_setters):
        upper = np.max([np.max(coords[k]), np.max(loadings[k])])
        lower = np.min([np.min(coords[k]), np.min(loadings[k])])
        set_limit(lower - 0.2, upper + 0.2)

    if three_d:
        ax.set_xlabel("PC1 ({}%)".format(var1), fontsize=axlabelfontsize, fontname=axlabelfontname)
        ax.set_ylabel("PC2 ({}%)".format(var2), fontsize=axlabelfontsize, fontname=axlabelfontname)
        ax.set_zlabel("PC3 ({}%)".format(var3), fontsize=axlabelfontsize, fontname=axlabelfontname)
        general.get_figure(show, r, figtype, 'biplot_3d', theme)
    else:
        general.axis_labels("PC1 ({}%)".format(var1), "PC2 ({}%)".format(var2), axlabelfontsize,
                            axlabelfontname)
        general.get_figure(show, r, figtype, 'biplot_2d', theme)