# visualize_5_41.py (7.1 KB)
  1. def biplot(cscore=None, loadings=None, labels=None, var1=None, var2=None, var3=None, axlabelfontsize=9, axlabelfontname="Arial",
  2. figtype='png', r=300, show=False, markerdot="o", dotsize=6, valphadot=1, colordot='#eba487', arrowcolor='#87ceeb',
  3. valphaarrow=1, arrowlinestyle='-', arrowlinewidth=0.5, centerlines=True, colorlist=None, legendpos='best',
  4. datapoints=True, dim=(6, 4), theme=None):
  5. if theme == 'dark':
  6. general.dark_bg()
  7. assert cscore is not None and loadings is not None and labels is not None and var1 is not None and var2 is not None, \
  8. "cscore or loadings or labels or var1 or var2 are missing"
  9. if var1 is not None and var2 is not None and var3 is None:
  10. xscale = 1.0 / (cscore[:, 0].max() - cscore[:, 0].min())
  11. yscale = 1.0 / (cscore[:, 1].max() - cscore[:, 1].min())
  12. # zscale = 1.0 / (cscore[:, 2].max() - cscore[:, 2].min())
  13. # colorlist is an array of classes from dataframe column
  14. plt.subplots(figsize=dim)
  15. if datapoints:
  16. if colorlist is not None:
  17. unique_class = set(colorlist)
  18. # color_dict = dict()
  19. assign_values = {col: i for i, col in enumerate(unique_class)}
  20. color_result_num = [assign_values[i] for i in colorlist]
  21. if colordot and isinstance(colordot, (tuple, list)):
  22. colour_map = ListedColormap(colordot)
  23. # for i in range(len(list(unique_class))):
  24. # color_dict[list(unique_class)[i]] = colordot[i]
  25. # color_result = [color_dict[i] for i in colorlist]
  26. s = plt.scatter(cscore[:, 0] * xscale, cscore[:, 1] * yscale, c=color_result_num, cmap=colour_map,
  27. s=dotsize, alpha=valphadot, marker=markerdot)
  28. plt.legend(handles=s.legend_elements()[0], labels=list(unique_class), loc=legendpos)
  29. elif colordot and not isinstance(colordot, (tuple, list)):
  30. # s = plt.scatter(cscore[:, 0] * xscale, cscore[:, 1] * yscale, color=color_result, s=dotsize,
  31. # alpha=valphadot, marker=markerdot)
  32. # plt.legend(handles=s.legend_elements()[0], labels=list(unique_class))
  33. s = plt.scatter(cscore[:, 0] * xscale, cscore[:, 1] * yscale, c=color_result_num, s=dotsize,
  34. alpha=valphadot, marker=markerdot)
  35. plt.legend(handles=s.legend_elements()[0], labels=list(unique_class), loc=legendpos)
  36. else:
  37. plt.scatter(cscore[:, 0] * xscale, cscore[:, 1] * yscale, color=colordot, s=dotsize,
  38. alpha=valphadot, marker=markerdot)
  39. if centerlines:
  40. plt.axhline(y=0, linestyle='--', color='#7d7d7d', linewidth=1)
  41. plt.axvline(x=0, linestyle='--', color='#7d7d7d', linewidth=1)
  42. # loadings[0] is the number of the original variables
  43. # this is important where variables more than number of observations
  44. for i in range(len(loadings[0])):
  45. plt.arrow(0, 0, loadings[0][i], loadings[1][i], color=arrowcolor, alpha=valphaarrow, ls=arrowlinestyle,
  46. lw=arrowlinewidth)
  47. plt.text(loadings[0][i], loadings[1][i], labels[i])
  48. # adjust_text(t)
  49. # plt.xlim(min(loadings[0]) - 0.1, max(loadings[0]) + 0.1)
  50. # plt.ylim(min(loadings[1]) - 0.1, max(loadings[1]) + 0.1)
  51. xlimit_max = np.max([np.max(cscore[:, 0]*xscale), np.max(loadings[0])])
  52. xlimit_min = np.min([np.min(cscore[:, 0]*xscale), np.min(loadings[0])])
  53. ylimit_max = np.max([np.max(cscore[:, 1]*yscale), np.max(loadings[1])])
  54. ylimit_min = np.min([np.min(cscore[:, 1]*yscale), np.min(loadings[1])])
  55. plt.xlim(xlimit_min-0.2, xlimit_max+0.2)
  56. plt.ylim(ylimit_min-0.2, ylimit_max+0.2)
  57. general.axis_labels("PC1 ({}%)".format(var1), "PC2 ({}%)".format(var2), axlabelfontsize, axlabelfontname)
  58. general.get_figure(show, r, figtype, 'biplot_2d', theme)
  59. # 3D
  60. if var1 is not None and var2 is not None and var3 is not None:
  61. xscale = 1.0 / (cscore[:, 0].max() - cscore[:, 0].min())
  62. yscale = 1.0 / (cscore[:, 1].max() - cscore[:, 1].min())
  63. zscale = 1.0 / (cscore[:, 2].max() - cscore[:, 2].min())
  64. fig = plt.figure(figsize=dim)
  65. ax = fig.add_subplot(111, projection='3d')
  66. if datapoints:
  67. if colorlist is not None:
  68. unique_class = set(colorlist)
  69. assign_values = {col: i for i, col in enumerate(unique_class)}
  70. color_result_num = [assign_values[i] for i in colorlist]
  71. if colordot and isinstance(colordot, (tuple, list)):
  72. colour_map = ListedColormap(colordot)
  73. s = ax.scatter(cscore[:, 0]*xscale, cscore[:, 1]*yscale, cscore[:, 2]*zscale, c=color_result_num,
  74. cmap=colour_map, s=dotsize, alpha=valphadot, marker=markerdot)
  75. plt.legend(handles=s.legend_elements()[0], labels=list(unique_class), loc=legendpos)
  76. elif colordot and not isinstance(colordot, (tuple, list)):
  77. s = ax.scatter(cscore[:, 0]*xscale, cscore[:, 1]*yscale, cscore[:, 2]*zscale, c=color_result_num,
  78. s=dotsize, alpha=valphadot, marker=markerdot)
  79. plt.legend(handles=s.legend_elements()[0], labels=list(unique_class), loc=legendpos)
  80. else:
  81. ax.scatter(cscore[:, 0] * xscale, cscore[:, 1] * yscale, cscore[:, 2] * zscale, color=colordot,
  82. s=dotsize, alpha=valphadot, marker=markerdot)
  83. for i in range(len(loadings[0])):
  84. ax.quiver(0, 0, 0, loadings[0][i], loadings[1][i], loadings[2][i], color=arrowcolor, alpha=valphaarrow,
  85. ls=arrowlinestyle, lw=arrowlinewidth)
  86. ax.text(loadings[0][i], loadings[1][i], loadings[2][i], labels[i])
  87. xlimit_max = np.max([np.max(cscore[:, 0] * xscale), np.max(loadings[0])])
  88. xlimit_min = np.min([np.min(cscore[:, 0] * xscale), np.min(loadings[0])])
  89. ylimit_max = np.max([np.max(cscore[:, 1] * yscale), np.max(loadings[1])])
  90. ylimit_min = np.min([np.min(cscore[:, 1] * yscale), np.min(loadings[1])])
  91. zlimit_max = np.max([np.max(cscore[:, 2] * zscale), np.max(loadings[2])])
  92. zlimit_min = np.min([np.min(cscore[:, 2] * zscale), np.min(loadings[2])])
  93. # ax.set_xlim(min(loadings[0])-0.1, max(loadings[0])+0.1)
  94. # ax.set_ylim(min(loadings[1])-0.1, max(loadings[1])+0.1)
  95. # ax.set_zlim(min(loadings[2])-0.1, max(loadings[2])+0.1)
  96. ax.set_xlim(xlimit_min-0.2, xlimit_max+0.2)
  97. ax.set_ylim(ylimit_min-0.2, ylimit_max+0.2)
  98. ax.set_zlim(zlimit_min-0.2, zlimit_max+0.2)
  99. ax.set_xlabel("PC1 ({}%)".format(var1), fontsize=axlabelfontsize, fontname=axlabelfontname)
  100. ax.set_ylabel("PC2 ({}%)".format(var2), fontsize=axlabelfontsize, fontname=axlabelfontname)
  101. ax.set_zlabel("PC3 ({}%)".format(var3), fontsize=axlabelfontsize, fontname=axlabelfontname)
  102. general.get_figure(show, r, figtype, 'biplot_3d', theme)