123456789101112131415161718192021222324252627282930 |
def tsneplot(score=None, axlabelfontsize=9, axlabelfontname="Arial", figtype='png', r=300, show=False,
             markerdot="o", dotsize=6, valphadot=1, colordot='#4a4e4d', colorlist=None, legendpos='best',
             figname='tsne_2d', dim=(6, 4), legendanchor=None, theme=None):
    """Draw a 2-D scatter plot of t-SNE coordinates, optionally coloured by class.

    Parameters
    ----------
    score : array-like
        Embedding coordinates. Column 0 is plotted on the x axis and column 1
        on the y axis. Anything ``numpy.asarray`` accepts (ndarray, nested list,
        pandas DataFrame) is allowed.
    axlabelfontsize, axlabelfontname :
        Font size and font family of the "t-SNE-1" / "t-SNE-2" axis labels.
    figtype, r, show, figname :
        Forwarded unchanged to ``general.get_figure`` (presumably output format,
        resolution, show-vs-save flag and output file name; confirm against
        ``general.get_figure``).
    markerdot, dotsize, valphadot :
        Marker style, marker size and alpha of the points.
    colordot : str or tuple/list of colours
        Without ``colorlist``: the single colour used for every point.
        With ``colorlist``: a tuple/list becomes a ``ListedColormap`` used to
        colour the classes; any other value (including a falsy one) falls back
        to matplotlib's default colormap.
    colorlist : sequence, optional
        One class label per row of ``score``. When given, points are coloured
        by class and a legend with one entry per class is drawn.
    legendpos, legendanchor :
        ``loc`` and ``bbox_to_anchor`` passed to ``plt.legend``.
    dim : tuple
        Figure size passed to ``plt.subplots(figsize=...)``.
    theme : str, optional
        ``'dark'`` calls ``general.dark_bg()`` before plotting. The value is also
        forwarded to ``general.get_figure``.

    Raises
    ------
    AssertionError
        If ``score`` is None.
    """
    import numpy as np

    if score is None:
        # Explicit check instead of `assert`, so it still fires under `python -O`.
        # AssertionError is kept so existing callers see the same exception type.
        raise AssertionError("score are missing")
    # Accept any array-like; `[:, 0]` indexing below needs an ndarray.
    score = np.asarray(score)

    if theme == 'dark':
        general.dark_bg()
    plt.subplots(figsize=dim)

    if colorlist is not None:
        # Deterministic class order: sorted when the labels are mutually
        # orderable, otherwise order of first appearance. Iterating a `set`
        # (the previous approach) varies between runs for string labels
        # because of hash randomisation.
        unique_class = list(dict.fromkeys(colorlist))
        try:
            unique_class.sort()
        except TypeError:
            pass  # mixed, unorderable label types: keep first-appearance order
        assign_values = {label: i for i, label in enumerate(unique_class)}
        color_result_num = [assign_values[label] for label in colorlist]

        scatter_kwargs = dict(c=color_result_num, s=dotsize, alpha=valphadot, marker=markerdot)
        if colordot and isinstance(colordot, (tuple, list)):
            scatter_kwargs['cmap'] = ListedColormap(colordot)
        s = plt.scatter(score[:, 0], score[:, 1], **scatter_kwargs)
        # num=None yields one handle per unique class code (0..k-1, ascending),
        # matching `unique_class` index-for-index. The default num="auto" only
        # does this for fewer than 10 classes.
        plt.legend(handles=s.legend_elements(num=None)[0], labels=unique_class, loc=legendpos,
                   bbox_to_anchor=legendanchor)
    else:
        plt.scatter(score[:, 0], score[:, 1], color=colordot,
                    s=dotsize, alpha=valphadot, marker=markerdot)

    plt.xlabel("t-SNE-1", fontsize=axlabelfontsize, fontname=axlabelfontname)
    plt.ylabel("t-SNE-2", fontsize=axlabelfontsize, fontname=axlabelfontname)
    general.get_figure(show, r, figtype, figname, theme)
|