visualize_5_30.py 1.1 KB

1234567891011121314151617181920
  1. def regplot(df="dataframe", x=None, y=None, yhat=None, dim=(6, 4), colordot='#4a4e4d', colorline='#fe8a71', r=300,
  2. ar=0, dotsize=6, valphaline=1, valphadot=1, linewidth=1, markerdot="o", show=False, axtickfontsize=9,
  3. axtickfontname="Arial", axlabelfontsize=9, axlabelfontname="Arial", ylm=None, xlm=None, axxlabel=None,
  4. axylabel=None, figtype='png', theme=None):
  5. if theme == 'dark':
  6. general.dark_bg()
  7. fig, ax = plt.subplots(figsize=dim)
  8. plt.scatter(df[x].to_numpy(), df[y].to_numpy(), color=colordot, s=dotsize, alpha=valphadot, marker=markerdot,
  9. label='Observed data')
  10. plt.plot(df[x].to_numpy(), df[yhat].to_numpy(), color=colorline, linewidth=linewidth, alpha=valphaline,
  11. label='Regression line')
  12. if axxlabel:
  13. x = axxlabel
  14. if axylabel:
  15. y = axylabel
  16. general.axis_labels(x, y, axlabelfontsize, axlabelfontname)
  17. general.axis_ticks(xlm, ylm, axtickfontsize, axtickfontname, ar)
  18. plt.legend(fontsize=9)
  19. general.get_figure(show, r, figtype, 'reg_plot', theme)