1234567891011121314151617181920 |
def regplot(df="dataframe", x=None, y=None, yhat=None, dim=(6, 4), colordot='#4a4e4d', colorline='#fe8a71', r=300,
            ar=0, dotsize=6, valphaline=1, valphadot=1, linewidth=1, markerdot="o", show=False, axtickfontsize=9,
            axtickfontname="Arial", axlabelfontsize=9, axlabelfontname="Arial", ylm=None, xlm=None, axxlabel=None,
            axylabel=None, figtype='png', theme=None):
    """Plot observed data as a scatter with a fitted line drawn on top.

    The columns ``x``, ``y`` and ``yhat`` are read from ``df`` before any
    plotting state is touched, so a wrong column name raises ``KeyError``
    without changing the theme or leaving an empty figure open.

    Args:
        df: Table indexable by column name whose columns expose
            ``to_numpy()`` (presumably a pandas DataFrame -- the default
            string is only a placeholder and is not usable as data).
        x: Column name used for the horizontal coordinate of both the
            scatter points and the line.
        y: Column name holding the observed values (scatter points).
        yhat: Column name holding the fitted values (line).
        dim: Figure size, passed to ``plt.subplots(figsize=...)``.
        colordot, dotsize, valphadot, markerdot: Color, size (``s``), alpha
            and marker of the scatter points.
        colorline, linewidth, valphaline: Color, width and alpha of the line.
        axxlabel, axylabel: Axis labels; when empty or ``None`` the column
            names ``x`` / ``y`` are used instead.
        axlabelfontsize, axlabelfontname: Forwarded to
            ``general.axis_labels``.
        xlm, ylm, axtickfontsize, axtickfontname, ar: Forwarded to
            ``general.axis_ticks`` (presumably axis limits, tick font and
            tick rotation -- confirm against that helper).
        show, r, figtype, theme: Forwarded to ``general.get_figure`` together
            with the base name ``'reg_plot'`` (presumably display-vs-save
            switch, resolution and file format -- confirm against that
            helper). ``theme == 'dark'`` additionally calls
            ``general.dark_bg()`` before the figure is created.

    Returns:
        None.

    Raises:
        KeyError: If ``x``, ``y`` or ``yhat`` is not a column of ``df``.

    Note:
        The line is drawn through the rows in the order they appear in
        ``df``; rows are not sorted by ``x`` here.
    """
    # Resolve all columns first: a bad column name must fail before the
    # global theme is changed or a figure is opened.
    x_values = df[x].to_numpy()
    y_values = df[y].to_numpy()
    yhat_values = df[yhat].to_numpy()

    if theme == 'dark':
        general.dark_bg()

    # The new figure becomes pyplot's current figure; the returned handles
    # are not needed because all drawing goes through the pyplot interface.
    plt.subplots(figsize=dim)
    plt.scatter(x_values, y_values, color=colordot, s=dotsize, alpha=valphadot, marker=markerdot,
                label='Observed data')
    plt.plot(x_values, yhat_values, color=colorline, linewidth=linewidth, alpha=valphaline,
             label='Regression line')

    # Fall back to the column names when no explicit axis label is given.
    xlabel = axxlabel if axxlabel else x
    ylabel = axylabel if axylabel else y
    general.axis_labels(xlabel, ylabel, axlabelfontsize, axlabelfontname)
    general.axis_ticks(xlm, ylm, axtickfontsize, axtickfontname, ar)
    plt.legend(fontsize=9)
    general.get_figure(show, r, figtype, 'reg_plot', theme)
|