123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
def _biplot_axis_scale(column):
    """Return the factor that rescales a score column to unit range.

    The factor is ``1 / (max - min)``. A constant column (zero range) gets 1.0
    instead, because dividing by zero would turn every point into inf/nan.
    """
    spread = column.max() - column.min()
    return 1.0 / spread if spread != 0 else 1.0


def _biplot_class_codes(colorlist):
    """Encode per-observation class labels as integer colour codes.

    Returns ``(classes, codes)``:

    - ``classes`` holds each distinct label once, in order of first appearance.
    - ``codes[k]`` is the position of ``colorlist[k]`` in ``classes``.

    First-appearance order keeps the colour and legend assignment the same
    between runs. Iterating a ``set`` of strings does not, because string
    hashing is randomised per interpreter.
    """
    classes = list(dict.fromkeys(colorlist))
    position = {label: i for i, label in enumerate(classes)}
    return classes, [position[label] for label in colorlist]


def _biplot_axis_limits(scaled_scores, loading_row, pad=0.2):
    """Return ``(low, high)`` axis limits covering both scaled scores and loadings, widened by ``pad``."""
    low = np.min([np.min(scaled_scores), np.min(loading_row)])
    high = np.max([np.max(scaled_scores), np.max(loading_row)])
    return low - pad, high + pad


def _biplot_scatter(ax, coords, colorlist, colordot, dotsize, valphadot, markerdot, legendpos):
    """Scatter the (already scaled) score coordinates ``coords`` on ``ax``.

    Without ``colorlist``, every point is drawn in ``colordot``. With it, points
    are coloured by class and a legend is added. A list/tuple ``colordot``
    becomes the colormap; any other value uses matplotlib's default colormap.
    """
    style = {'s': dotsize, 'alpha': valphadot, 'marker': markerdot}
    if colorlist is None:
        ax.scatter(*coords, color=colordot, **style)
        return
    classes, codes = _biplot_class_codes(colorlist)
    if colordot and isinstance(colordot, (tuple, list)):
        style['cmap'] = ListedColormap(colordot)
    # Any other colordot (a single colour, or a falsy value) falls back to the
    # default colormap, so the points are always drawn.
    points = ax.scatter(*coords, c=codes, **style)
    # num=None asks for one handle per distinct code. The default num='auto'
    # thins the handles once there are more than 9 classes, which would
    # misalign them with the labels.
    handles = points.legend_elements(num=None)[0]
    ax.legend(handles=handles, labels=classes, loc=legendpos)


def biplot(cscore=None, loadings=None, labels=None, var1=None, var2=None, var3=None, axlabelfontsize=9,
           axlabelfontname="Arial", figtype='png', r=300, show=False, markerdot="o", dotsize=6, valphadot=1,
           colordot='#eba487', arrowcolor='#87ceeb', valphaarrow=1, arrowlinestyle='-', arrowlinewidth=0.5,
           centerlines=True, colorlist=None, legendpos='best', datapoints=True, dim=(6, 4), theme=None):
    """Draw a PCA biplot (scores as points, loadings as arrows) and pass it to ``general.get_figure``.

    A 3D biplot ('biplot_3d') is drawn when ``var3`` is given; otherwise a 2D
    biplot ('biplot_2d') is drawn.

    Args:
        cscore: scores indexed as ``cscore[:, k]`` for component k (k < 2, or k < 3 in 3D).
            Each column is rescaled by 1 / (max - min) before plotting.
        loadings: indexed as ``loadings[k][i]``, the loading of original variable i on
            component k. One arrow is drawn per entry of ``loadings[0]``.
        labels: text placed at the tip of arrow i (``labels[i]``).
        var1, var2, var3: values shown in the axis labels as "PCk (value%)".
        axlabelfontsize, axlabelfontname: axis label font.
        figtype, r, show, theme: forwarded to ``general.get_figure``. ``theme='dark'``
            also calls ``general.dark_bg()`` first.
        markerdot, dotsize, valphadot: marker, size and alpha of the score points.
        colordot: point colour. When ``colorlist`` is given, a list/tuple is used as the
            per-class colormap; any other value uses matplotlib's default colormap.
        arrowcolor, valphaarrow, arrowlinestyle, arrowlinewidth: loading arrow style.
        centerlines: draw dashed lines through the origin (2D only).
        colorlist: per-observation class labels. Points are coloured by class and a
            legend is added.
        legendpos: legend ``loc``.
        datapoints: whether to draw the score points at all.
        dim: figure size.

    Raises:
        AssertionError: if cscore, loadings, labels, var1 or var2 is missing.
    """
    if theme == 'dark':
        general.dark_bg()
    assert cscore is not None and loadings is not None and labels is not None and var1 is not None \
        and var2 is not None, "cscore or loadings or labels or var1 or var2 are missing"

    three_d = var3 is not None
    ndim = 3 if three_d else 2
    # Rescale each score column to unit range so the points share axes with the loading arrows.
    scaled = [cscore[:, k] * _biplot_axis_scale(cscore[:, k]) for k in range(ndim)]

    if three_d:
        fig = plt.figure(figsize=dim)
        ax = fig.add_subplot(111, projection='3d')
    else:
        _, ax = plt.subplots(figsize=dim)

    if datapoints:
        _biplot_scatter(ax, scaled, colorlist, colordot, dotsize, valphadot, markerdot, legendpos)
    if centerlines and not three_d:
        ax.axhline(y=0, linestyle='--', color='#7d7d7d', linewidth=1)
        ax.axvline(x=0, linestyle='--', color='#7d7d7d', linewidth=1)

    # One arrow per original variable. This can exceed the number of observations.
    arrow_style = {'color': arrowcolor, 'alpha': valphaarrow, 'ls': arrowlinestyle, 'lw': arrowlinewidth}
    for i in range(len(loadings[0])):
        tip = [loadings[k][i] for k in range(ndim)]
        if three_d:
            ax.quiver(0, 0, 0, *tip, **arrow_style)
        else:
            ax.arrow(0, 0, *tip, **arrow_style)
        ax.text(*tip, labels[i])

    # Limits cover both the scaled points and the arrow tips, with a small margin.
    setters = (ax.set_xlim, ax.set_ylim, ax.set_zlim) if three_d else (ax.set_xlim, ax.set_ylim)
    for k, set_limits in enumerate(setters):
        set_limits(*_biplot_axis_limits(scaled[k], loadings[k]))

    if three_d:
        font = {'fontsize': axlabelfontsize, 'fontname': axlabelfontname}
        ax.set_xlabel("PC1 ({}%)".format(var1), **font)
        ax.set_ylabel("PC2 ({}%)".format(var2), **font)
        ax.set_zlabel("PC3 ({}%)".format(var3), **font)
        general.get_figure(show, r, figtype, 'biplot_3d', theme)
    else:
        general.axis_labels("PC1 ({}%)".format(var1), "PC2 ({}%)".format(var2), axlabelfontsize,
                            axlabelfontname)
        general.get_figure(show, r, figtype, 'biplot_2d', theme)
|