visualize_5_42.py 1.7 KB

123456789101112131415161718192021222324252627282930
  1. def tsneplot(score=None, axlabelfontsize=9, axlabelfontname="Arial", figtype='png', r=300, show=False,
  2. markerdot="o", dotsize=6, valphadot=1, colordot='#4a4e4d', colorlist=None, legendpos='best',
  3. figname='tsne_2d', dim=(6, 4), legendanchor=None, theme=None):
  4. assert score is not None, "score are missing"
  5. if theme == 'dark':
  6. general.dark_bg()
  7. plt.subplots(figsize=dim)
  8. if colorlist is not None:
  9. unique_class = set(colorlist)
  10. # color_dict = dict()
  11. assign_values = {col: i for i, col in enumerate(unique_class)}
  12. color_result_num = [assign_values[i] for i in colorlist]
  13. if colordot and isinstance(colordot, (tuple, list)):
  14. colour_map = ListedColormap(colordot)
  15. s = plt.scatter(score[:, 0], score[:, 1], c=color_result_num, cmap=colour_map,
  16. s=dotsize, alpha=valphadot, marker=markerdot)
  17. plt.legend(handles=s.legend_elements()[0], labels=list(unique_class), loc=legendpos,
  18. bbox_to_anchor=legendanchor)
  19. elif colordot and not isinstance(colordot, (tuple, list)):
  20. s = plt.scatter(score[:, 0], score[:, 1], c=color_result_num,
  21. s=dotsize, alpha=valphadot, marker=markerdot)
  22. plt.legend(handles=s.legend_elements()[0], labels=list(unique_class), loc=legendpos,
  23. bbox_to_anchor=legendanchor)
  24. else:
  25. plt.scatter(score[:, 0], score[:, 1], color=colordot,
  26. s=dotsize, alpha=valphadot, marker=markerdot)
  27. plt.xlabel("t-SNE-1", fontsize=axlabelfontsize, fontname=axlabelfontname)
  28. plt.ylabel("t-SNE-2", fontsize=axlabelfontsize, fontname=axlabelfontname)
  29. general.get_figure(show, r, figtype, figname, theme)