visualize_5_38.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. def roc(fpr=None, tpr=None, c_line_style='-', c_line_color='#f05f21', c_line_width=1, diag_line=True,
  2. diag_line_style='--', diag_line_width=1, diag_line_color='b', auc=None, shade_auc=False,
  3. shade_auc_color='#f48d60',
  4. axxlabel='False Positive Rate (1 - Specificity)', axylabel='True Positive Rate (Sensitivity)', ar=(0, 0),
  5. axtickfontsize=9, axtickfontname='Arial', axlabelfontsize=9, axlabelfontname='Arial',
  6. plotlegend=True, legendpos='lower right', legendanchor=None, legendcols=1, legendfontsize=8,
  7. legendlabelframe=False, legend_columnspacing=None, per_class=False, dim=(6, 5), show=False, figtype='png',
  8. figname='roc', r=300, ylm=None, theme=None):
  9. if theme == 'dark':
  10. general.dark_bg()
  11. plt.subplots(figsize=dim)
  12. # plt.margins(x=0)
  13. if auc:
  14. plt.plot(fpr, tpr, color=c_line_color, linestyle=c_line_style, linewidth=c_line_width,
  15. label='AUC = %0.4f' % auc)
  16. else:
  17. plt.plot(fpr, tpr, color=c_line_color, linestyle=c_line_style, linewidth=c_line_width)
  18. if diag_line:
  19. plt.plot([0, 1], [0, 1], color=diag_line_color, linestyle=diag_line_style, linewidth=diag_line_width,
  20. label='Chance level')
  21. if per_class:
  22. plt.plot([0, 0], [0, 1], color='grey', linestyle='-', linewidth=1)
  23. plt.plot([0, 1], [1, 1], color='grey', linestyle='-', linewidth=1, label='Perfect performance')
  24. # ylm must be tuple of start, end, interval
  25. if ylm:
  26. plt.ylim(bottom=ylm[0], top=ylm[1])
  27. plt.yticks(np.arange(ylm[0], ylm[1], ylm[2]), fontsize=axtickfontsize, fontname=axtickfontname)
  28. plt.yticks(fontsize=axtickfontsize, rotation=ar[1], fontname=axtickfontname)
  29. if axxlabel:
  30. _x = axxlabel
  31. if axylabel:
  32. _y = axylabel
  33. if shade_auc:
  34. plt.fill_between(x=fpr, y1=tpr, color=shade_auc_color)
  35. if plotlegend:
  36. plt.legend(loc=legendpos, bbox_to_anchor=legendanchor, ncol=legendcols, fontsize=legendfontsize,
  37. frameon=legendlabelframe, columnspacing=legend_columnspacing)
  38. general.axis_labels(_x, _y, axlabelfontsize, axlabelfontname)
  39. general.get_figure(show, r, figtype, figname, theme)