# visualize_5_35.py (13 KB)

  1. def singlebar(df='dataframe', dim=(6, 4), bw=0.4, colorbar='#f2aa4cff', hbsize=4, r=300, ar=(0, 0), valphabar=1,
  2. errorbar=True, show=False, ylm=None, axtickfontsize=9, axtickfontname='Arial', ax_x_ticklabel=None,
  3. axlabelfontsize=9, axlabelfontname='Arial', yerrlw=None, yerrcw=None, axxlabel=None, axylabel=None,
  4. figtype='png', add_sign_line=False, pv=None,
  5. sign_line_opts={'symbol': '*', 'fontsize': 9, 'linewidth': 0.5, 'arrowstyle': '-', 'fontname':'Arial'},
  6. sign_line_pvals=False,
  7. add_sign_symbol=False, sign_symbol_opts={'symbol': '*', 'fontsize': 9, 'rotation':0, 'fontname':'Arial'},
  8. sign_line_pairs=None, sub_cat=None, sub_cat_opts={'y_neg_dist': 3.5, 'fontsize': 9, 'fontname':'Arial'},
  9. sub_cat_label_dist=None, symb_dist=None, group_let=None, df_format=None, samp_col_name=None,
  10. col_order=False, dotplot=False, dotsize=6, colordot=['#101820ff'], valphadot=1, markerdot='o',
  11. sign_line_pairs_dist=None, sign_line_pv_symb_dist=None, div_fact=20, add_text=None,
  12. figname='singlebar', connectionstyle='bar, armA=50, armB=50, angle=180, fraction=0',
  13. std_errs_vis='both', yerrzorder=8, theme=None):
  14. plt.rcParams['mathtext.fontset'] = 'custom'
  15. plt.rcParams['mathtext.default'] = 'regular'
  16. plt.rcParams['mathtext.it'] = 'Arial:italic'
  17. plt.rcParams['mathtext.bf'] = 'Arial:italic:bold'
  18. # set axis labels to None
  19. _x = None
  20. _y = None
  21. if df_format == 'stack':
  22. # sample_list = df[samp_col_name].unique()
  23. if samp_col_name is None:
  24. raise ValueError('sample column name required')
  25. df_mean = df.groupby(samp_col_name).mean().reset_index().set_index(samp_col_name).T
  26. df_sem = df.groupby(samp_col_name).sem().reset_index().set_index(samp_col_name).T
  27. if col_order:
  28. df_mean = df_mean[df[samp_col_name].unique()]
  29. df_sem = df_sem[df[samp_col_name].unique()]
  30. bar_h = df_mean.iloc[0]
  31. bar_se = df_sem.iloc[0]
  32. sample_list = df_mean.columns.to_numpy()
  33. # get minimum from df
  34. min_value = (0, df_mean.iloc[0].min())[df_mean.iloc[0].min() < 0]
  35. else:
  36. bar_h = df.describe().loc['mean']
  37. bar_se = df.sem()
  38. bar_counts = df.describe().loc['count']
  39. sample_list = df.columns.to_numpy()
  40. min_value = (0, min(df.min()))[min(df.min()) < 0]
  41. if std_errs_vis == 'upper':
  42. std_errs_vis = [len(bar_se)*[0], bar_se]
  43. elif std_errs_vis == 'lower':
  44. std_errs_vis = [bar_se, len(bar_se)*[0]]
  45. elif std_errs_vis == 'both':
  46. std_errs_vis = bar_se
  47. else:
  48. raise ValueError('In valid value for the std_errs_vis')
  49. xbar = np.arange(len(sample_list))
  50. color_list_bar = colorbar
  51. if theme == 'dark':
  52. general.dark_bg()
  53. plt.subplots(figsize=dim)
  54. if errorbar:
  55. plt.bar(x=xbar, height=bar_h, yerr=std_errs_vis, width=bw, color=color_list_bar,
  56. capsize=hbsize, alpha=valphabar, zorder=5, error_kw={'elinewidth': yerrlw, 'capthick': yerrcw,
  57. 'zorder': yerrzorder})
  58. else:
  59. plt.bar(x=xbar, height=bar_h, width=bw, color=color_list_bar, capsize=hbsize, alpha=valphabar)
  60. if ax_x_ticklabel:
  61. x_ticklabel = ax_x_ticklabel
  62. else:
  63. x_ticklabel = sample_list
  64. plt.xticks(ticks=xbar, labels=x_ticklabel, fontsize=axtickfontsize, rotation=ar[0], fontname=axtickfontname)
  65. if axxlabel:
  66. _x = axxlabel
  67. if axylabel:
  68. _y = axylabel
  69. general.axis_labels(_x, _y, axlabelfontsize, axlabelfontname)
  70. # ylm must be tuple of start, end, interval
  71. if ylm:
  72. plt.ylim(bottom=ylm[0], top=ylm[1])
  73. plt.yticks(np.arange(ylm[0], ylm[1], ylm[2]), fontsize=axtickfontsize, fontname=axtickfontname)
  74. plt.yticks(fontsize=axtickfontsize, rotation=ar[1], fontname=axtickfontname)
  75. color_list_dot = colordot
  76. if len(color_list_dot) == 1:
  77. color_list_dot = colordot * len(sample_list)
  78. # checked for unstacked data
  79. if dotplot:
  80. for cols in range(len(sample_list)):
  81. plt.scatter(
  82. x=np.linspace(xbar[cols] - bw / 2, xbar[cols] + bw / 2, int(bar_counts[cols])),
  83. y=df[df.columns[cols]].dropna(), s=dotsize, color=color_list_dot[cols], zorder=10, alpha=valphadot,
  84. marker=markerdot)
  85. size_factor_to_start_line = max(bar_h) / div_fact
  86. # for only adjacent bars (not for multiple bars with single control)
  87. if add_sign_line:
  88. for i in xbar:
  89. if i % 2 != 0:
  90. continue
  91. x_pos = xbar[i]
  92. x_pos_2 = xbar[i+1]
  93. y_pos = df.describe().loc['mean'].to_numpy()[i] + df.sem().to_numpy()[i]
  94. y_pos_2 = df.describe().loc['mean'].to_numpy()[i+1] + df.sem().to_numpy()[i+1]
  95. # only if y axis is positive; in future make a function to call it (2 times used)
  96. if y_pos > 0:
  97. y_pos += size_factor_to_start_line
  98. y_pos_2 += size_factor_to_start_line
  99. pv_symb = general.pvalue_symbol(pv[int(i/2)], sign_line_opts['symbol'])
  100. if pv_symb:
  101. plt.annotate('', xy=(x_pos, max(y_pos, y_pos_2)), xytext=(x_pos_2, max(y_pos, y_pos_2)),
  102. arrowprops={'connectionstyle': connectionstyle,
  103. 'arrowstyle': sign_line_opts['arrowstyle'],
  104. 'linewidth': sign_line_opts['linewidth']})
  105. plt.annotate(pv_symb, xy=(np.mean([x_pos, x_pos_2]), max(y_pos, y_pos_2) +
  106. sign_line_opts['dist_y_pos']),
  107. fontsize=sign_line_opts['fontsize'], ha="center")
  108. # for only adjacent bars with one control but multiple treatments
  109. # need to work for sign_line_pairs (update df on line 1276)
  110. p_index = 0
  111. y_pos_dict = dict()
  112. y_pos_dict_trt = dict()
  113. if sign_line_pairs:
  114. for i in sign_line_pairs:
  115. y_pos_adj = 0
  116. x_pos = xbar[i[0]]
  117. x_pos_2 = xbar[i[1]]
  118. y_pos = df.describe().loc['mean'].to_numpy()[i[0]] + df.sem().to_numpy()[i[0]]
  119. y_pos_2 = df.describe().loc['mean'].to_numpy()[i[1]] + df.sem().to_numpy()[i[1]]
  120. # only if y axis is positive; in future make a function to call it (2 times used)
  121. if y_pos > 0:
  122. y_pos += size_factor_to_start_line/2
  123. y_pos_2 += size_factor_to_start_line/2
  124. # check if the mean of y_pos is not lesser than not other treatments which lies between
  125. # eg if 0-1 has higher sign bar than the 0-2
  126. if i[0] in y_pos_dict_trt:
  127. y_pos_adj = 1
  128. if y_pos_2 <= y_pos_dict_trt[i[0]][1]:
  129. if sign_line_pairs_dist:
  130. y_pos_2 += (y_pos_dict_trt[i[0]][1] - y_pos_2) + (3 * size_factor_to_start_line) + \
  131. sign_line_pairs_dist[p_index]
  132. else:
  133. y_pos_2 += (y_pos_dict_trt[i[0]][1] - y_pos_2) + (3 * size_factor_to_start_line)
  134. elif y_pos <= y_pos_dict_trt[i[0]][0]:
  135. if sign_line_pairs_dist:
  136. y_pos += 3 * size_factor_to_start_line + sign_line_pairs_dist[p_index]
  137. else:
  138. y_pos += 3 * size_factor_to_start_line
  139. # check if difference is not equivalent between two y_pos
  140. # if yes add some distance, so that sign bar will not overlap
  141. if i[0] in y_pos_dict:
  142. y_pos_adj = 1
  143. if 0.75 < df.describe().loc['mean'].to_numpy()[i[0]]/df.describe().loc['mean'].to_numpy()[i[1]] < 1.25:
  144. if sign_line_pairs_dist:
  145. y_pos += 2 * size_factor_to_start_line + sign_line_pairs_dist[p_index]
  146. else:
  147. y_pos += 2 * size_factor_to_start_line
  148. if y_pos_adj == 0 and sign_line_pairs_dist:
  149. if y_pos >= y_pos_2:
  150. y_pos += sign_line_pairs_dist[p_index]
  151. else:
  152. y_pos_2 += sign_line_pairs_dist[p_index]
  153. # sign_line_pvals passed, used p values instead of symbols
  154. if sign_line_pvals:
  155. pv_symb = '$\it{p}$'+ str(pv[p_index])
  156. else:
  157. pv_symb = general.pvalue_symbol(pv[p_index], sign_line_opts['symbol'])
  158. y_pos_dict[i[0]] = y_pos
  159. y_pos_dict_trt[i[0]] = [y_pos, y_pos_2]
  160. if pv_symb:
  161. plt.annotate('', xy=(x_pos, max(y_pos, y_pos_2)), xytext=(x_pos_2, max(y_pos, y_pos_2)),
  162. arrowprops={'connectionstyle': connectionstyle,
  163. 'arrowstyle': sign_line_opts['arrowstyle'],
  164. 'linewidth': sign_line_opts['linewidth']})
  165. # here size factor size_factor_to_start_line added instead of sign_line_opts['dist_y_pos']
  166. # make this change everywhere in future release
  167. plt.annotate(pv_symb, xy=(np.mean([x_pos, x_pos_2]), max(y_pos, y_pos_2) +
  168. size_factor_to_start_line + sign_line_pv_symb_dist[p_index]),
  169. fontsize=sign_line_opts['fontsize'], ha="center")
  170. p_index += 1
  171. if add_sign_symbol:
  172. for i in xbar:
  173. x_pos = xbar[i]
  174. # y_pos = df.describe().loc['mean'].to_numpy()[i] + df.sem().to_numpy()[i] + size_factor_to_start_line
  175. if symb_dist:
  176. y_pos = bar_h.to_numpy()[i] + bar_se.to_numpy()[i] + \
  177. size_factor_to_start_line + symb_dist[i]
  178. else:
  179. y_pos = bar_h.to_numpy()[i] + bar_se.to_numpy()[i] + \
  180. size_factor_to_start_line
  181. # group_let list
  182. if isinstance(group_let, list):
  183. if y_pos > 0:
  184. plt.annotate(group_let[i], xy=(x_pos, y_pos),
  185. fontsize=sign_symbol_opts['fontsize'], ha="center",
  186. rotation=sign_symbol_opts['rotation'], fontfamily=sign_symbol_opts['fontname'])
  187. # only if y axis is positive
  188. if pv:
  189. if y_pos > 0:
  190. pv_symb = general.pvalue_symbol(pv[i], sign_symbol_opts['symbol'])
  191. if pv_symb:
  192. plt.annotate(pv_symb, xy=(x_pos, y_pos), fontsize=sign_symbol_opts['fontsize'], ha="center",
  193. rotation=sign_symbol_opts['rotation'], fontfamily=sign_symbol_opts['fontname'])
  194. sub_cat_i = 0
  195. if sub_cat:
  196. if isinstance(sub_cat, dict):
  197. for k in sub_cat:
  198. if isinstance(k, tuple) and len(k) == 2:
  199. cat_x_pos, cat_y_pos, cat_x_pos_2 = k[0], min_value - \
  200. (sub_cat_opts['y_neg_dist']*size_factor_to_start_line), k[1]
  201. plt.annotate('', xy=(cat_x_pos-(bw/2), cat_y_pos), xytext=(cat_x_pos_2+(bw/2), cat_y_pos),
  202. arrowprops={'arrowstyle': '-', 'linewidth': 0.5}, annotation_clip=False)
  203. if sub_cat_label_dist and isinstance(sub_cat_label_dist, list):
  204. plt.annotate(sub_cat[k], xy=(np.mean([cat_x_pos, cat_x_pos_2]),
  205. cat_y_pos - size_factor_to_start_line - sub_cat_label_dist[sub_cat_i]),
  206. ha="center", fontsize=sub_cat_opts['fontsize'], annotation_clip=False,
  207. fontfamily=sub_cat_opts['fontname'])
  208. sub_cat_i += 1
  209. else:
  210. plt.annotate(sub_cat[k], xy=(np.mean([cat_x_pos, cat_x_pos_2]),
  211. cat_y_pos-size_factor_to_start_line),
  212. ha="center", fontsize=sub_cat_opts['fontsize'], annotation_clip=False,
  213. fontfamily=sub_cat_opts['fontname'])
  214. else:
  215. raise KeyError("Sub category keys must be tuple of size 2")
  216. if isinstance(add_text, list):
  217. plt.text(add_text[0], add_text[1], add_text[2], fontsize=9, fontfamily='Arial')
  218. general.get_figure(show, r, figtype, figname, theme)