visualize_5.py 116 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960
  1. """
  2. visuz module implements visualization functions related to Bioinformatics, Statistics and Machine learning:
  3. Gene expression data visualization
  4. Molecular marker data visualization
  5. Statistical and Machine learning visualization
  6. """
  7. import pandas as pd
  8. import matplotlib.pyplot as plt
  9. import numpy as np
  10. import seaborn as sns
  11. from matplotlib_venn import venn3, venn2
  12. from random import sample
  13. from functools import reduce
  14. import sys
  15. from matplotlib.colors import ListedColormap
  16. __all__ = ['GeneExpression', 'General', 'gene_exp', 'general', 'marker', 'marker', 'stat', 'cluster']
  17. def venn(vennset=(1,1,1,1,1,1,1), venncolor=('#00909e', '#f67280', '#ff971d'), vennalpha=0.5,
  18. vennlabel=('A', 'B', 'C')):
  19. fig = plt.figure()
  20. if len(vennset) == 7:
  21. venn3(subsets=vennset, set_labels=vennlabel, set_colors=venncolor, alpha=vennalpha)
  22. plt.savefig('venn3.png', format='png', bbox_inches='tight', dpi=300)
  23. elif len(vennset) == 3:
  24. venn2(subsets=vennset, set_labels=vennlabel, set_colors=venncolor, alpha=vennalpha)
  25. plt.savefig('venn2.png', format='png', bbox_inches='tight', dpi=300)
  26. else:
  27. print("Error: check the set dataset")
  28. class GeneExpression:
  29. def __init__(self):
  30. pass
  31. @staticmethod
  32. def gene_plot(d, geneid, lfc, lfc_thr, pv_thr, genenames, gfont, pv, gstyle):
  33. if genenames is not None and genenames == "deg":
  34. for i in d[geneid].unique():
  35. if (d.loc[d[geneid] == i, lfc].iloc[0] >= lfc_thr[0] and d.loc[d[geneid] == i, pv].iloc[0] < pv_thr[0]) or \
  36. (d.loc[d[geneid] == i, lfc].iloc[0] <= -lfc_thr[1] and d.loc[d[geneid] == i, pv].iloc[0] < pv_thr[1]):
  37. if gstyle == 1:
  38. plt.text(d.loc[d[geneid] == i, lfc].iloc[0], d.loc[d[geneid] == i, 'logpv_add_axy'].iloc[0], i,
  39. fontsize=gfont)
  40. elif gstyle == 2:
  41. plt.annotate(i, xy=(d.loc[d[geneid] == i, lfc].iloc[0], d.loc[d[geneid] == i, 'logpv_add_axy'].iloc[0]),
  42. xycoords='data', xytext=(5, -15), textcoords='offset points', size=6,
  43. bbox=dict(boxstyle="round", alpha=0.1),
  44. arrowprops=dict(arrowstyle="wedge,tail_width=0.5", alpha=0.1, relpos=(0, 0)))
  45. else:
  46. print("Error: invalid gstyle choice")
  47. sys.exit(1)
  48. elif genenames is not None and type(genenames) is tuple:
  49. for i in d[geneid].unique():
  50. if i in genenames:
  51. if gstyle == 1:
  52. plt.text(d.loc[d[geneid] == i, lfc].iloc[0], d.loc[d[geneid] == i, 'logpv_add_axy'].iloc[0], i,
  53. fontsize=gfont)
  54. elif gstyle == 2:
  55. plt.annotate(i, xy=(d.loc[d[geneid] == i, lfc].iloc[0], d.loc[d[geneid] == i, 'logpv_add_axy'].iloc[0]),
  56. xycoords='data', xytext=(5, -15), textcoords='offset points', size=6,
  57. bbox=dict(boxstyle="round", alpha=0.1),
  58. arrowprops=dict(arrowstyle="wedge,tail_width=0.5", alpha=0.1, relpos=(0, 0)))
  59. else:
  60. print("Error: invalid gstyle choice")
  61. sys.exit(1)
  62. elif genenames is not None and type(genenames) is dict:
  63. for i in d[geneid].unique():
  64. if i in genenames:
  65. if gstyle == 1:
  66. plt.text(d.loc[d[geneid] == i, lfc].iloc[0], d.loc[d[geneid] == i, 'logpv_add_axy'].iloc[0],
  67. genenames[i], fontsize=gfont)
  68. elif gstyle == 2:
  69. plt.annotate(genenames[i], xy=(d.loc[d[geneid] == i, lfc].iloc[0], d.loc[d[geneid] == i, 'logpv_add_axy'].iloc[0]),
  70. xycoords='data', xytext=(5, -15), textcoords='offset points', size=6,
  71. bbox=dict(boxstyle="round", alpha=0.1),
  72. arrowprops=dict(arrowstyle="wedge,tail_width=0.5", alpha=0.1, relpos=(0, 0)))
  73. else:
  74. print("Error: invalid gstyle choice")
  75. sys.exit(1)
  76. @staticmethod
  77. def geneplot_ma(df, geneid, lfc, lfc_thr, genenames, gfont, gstyle):
  78. if genenames is not None and genenames == "deg":
  79. for i in df[geneid].unique():
  80. if df.loc[df[geneid] == i, lfc].iloc[0] >= lfc_thr[0] or \
  81. df.loc[df[geneid] == i, lfc].iloc[0] <= -lfc_thr[1]:
  82. if gstyle == 1:
  83. plt.text(df.loc[df[geneid] == i, 'A_add_axy'].iloc[0], df.loc[df[geneid] == i, lfc].iloc[0], i,
  84. fontsize=gfont)
  85. elif gstyle == 2:
  86. plt.annotate(i, xy=(df.loc[df[geneid] == i, 'A_add_axy'].iloc[0],
  87. df.loc[df[geneid] == i, lfc].iloc[0]),
  88. xycoords='data', xytext=(5, -15), textcoords='offset points', size=6,
  89. bbox=dict(boxstyle="round", alpha=0.1),
  90. arrowprops=dict(arrowstyle="wedge,tail_width=0.5", alpha=0.1, relpos=(0, 0)))
  91. else:
  92. print("Error: invalid gstyle choice")
  93. sys.exit(1)
  94. elif genenames is not None and type(genenames) is tuple:
  95. for i in df[geneid].unique():
  96. if i in genenames:
  97. if gstyle == 1:
  98. plt.text(df.loc[df[geneid] == i, 'A_add_axy'].iloc[0], df.loc[df[geneid] == i, lfc].iloc[0], i,
  99. fontsize=gfont)
  100. elif gstyle == 2:
  101. plt.annotate(i, xy=(df.loc[df[geneid] == i, 'A_add_axy'].iloc[0],
  102. df.loc[df[geneid] == i, lfc].iloc[0]),
  103. xycoords='data', xytext=(5, -15), textcoords='offset points', size=6,
  104. bbox=dict(boxstyle="round", alpha=0.1),
  105. arrowprops=dict(arrowstyle="wedge,tail_width=0.5", alpha=0.1, relpos=(0, 0)))
  106. else:
  107. print("Error: invalid gstyle choice")
  108. sys.exit(1)
  109. elif genenames is not None and type(genenames) is dict:
  110. for i in df[geneid].unique():
  111. if i in genenames:
  112. if gstyle == 1:
  113. plt.text(df.loc[df[geneid] == i, 'A_add_axy'].iloc[0], df.loc[df[geneid] == i, lfc].iloc[0],
  114. genenames[i], fontsize=gfont)
  115. elif gstyle == 2:
  116. plt.annotate(genenames[i], xy=(df.loc[df[geneid] == i, 'A_add_axy'].iloc[0],
  117. df.loc[df[geneid] == i, lfc].iloc[0]),
  118. xycoords='data', xytext=(5, -15), textcoords='offset points', size=6,
  119. bbox=dict(boxstyle="round", alpha=0.1),
  120. arrowprops=dict(arrowstyle="wedge,tail_width=0.5", alpha=0.1, relpos=(0, 0)))
  121. else:
  122. print("Error: invalid gstyle choice")
  123. sys.exit(1)
  124. def volcano(df="dataframe", lfc=None, pv=None, lfc_thr=(1, 1), pv_thr=(0.05, 0.05), color=("green", "grey", "red"),
  125. valpha=1, geneid=None, genenames=None, gfont=8, dim=(5, 5), r=300, ar=90, dotsize=8, markerdot="o",
  126. sign_line=False, gstyle=1, show=False, figtype='png', axtickfontsize=9,
  127. axtickfontname="Arial", axlabelfontsize=9, axlabelfontname="Arial", axxlabel=None,
  128. axylabel=None, xlm=None, ylm=None, plotlegend=False, legendpos='best',
  129. figname='volcano', legendanchor=None,
  130. legendlabels=['significant up', 'not significant', 'significant down'], theme=None):
  131. _x = r'$ log_{2}(Fold Change)$'
  132. _y = r'$ -log_{10}(P-value)$'
  133. color = color
  134. # check if dataframe contains any non-numeric character
  135. assert general.check_for_nonnumeric(df[lfc]) == 0, 'dataframe contains non-numeric values in lfc column'
  136. assert general.check_for_nonnumeric(df[pv]) == 0, 'dataframe contains non-numeric values in pv column'
  137. # this is important to check if color or logpv exists and drop them as if you run multiple times same command
  138. # it may update old instance of df
  139. df = df.drop(['color_add_axy', 'logpv_add_axy'], axis=1, errors='ignore')
  140. assert len(set(color)) == 3, 'unique color must be size of 3'
  141. df.loc[(df[lfc] >= lfc_thr[0]) & (df[pv] < pv_thr[0]), 'color_add_axy'] = color[0] # upregulated
  142. df.loc[(df[lfc] <= -lfc_thr[1]) & (df[pv] < pv_thr[1]), 'color_add_axy'] = color[2] # downregulated
  143. df['color_add_axy'].fillna(color[1], inplace=True) # intermediate
  144. df['logpv_add_axy'] = -(np.log10(df[pv]))
  145. # plot
  146. assign_values = {col: i for i, col in enumerate(color)}
  147. color_result_num = [assign_values[i] for i in df['color_add_axy']]
  148. assert len(set(color_result_num)) == 3, \
  149. 'either significant or non-significant genes are missing; try to change lfc_thr or pv_thr to include ' \
  150. 'both significant and non-significant genes'
  151. if theme == 'dark':
  152. general.dark_bg()
  153. plt.subplots(figsize=dim)
  154. if plotlegend:
  155. s = plt.scatter(df[lfc], df['logpv_add_axy'], c=color_result_num, cmap=ListedColormap(color), alpha=valpha,
  156. s=dotsize, marker=markerdot)
  157. assert len(legendlabels) == 3, 'legendlabels must be size of 3'
  158. plt.legend(handles=s.legend_elements()[0], labels=legendlabels, loc=legendpos, bbox_to_anchor=legendanchor)
  159. else:
  160. plt.scatter(df[lfc], df['logpv_add_axy'], c=color_result_num, cmap=ListedColormap(color), alpha=valpha,
  161. s=dotsize, marker=markerdot)
  162. if sign_line:
  163. plt.axhline(y=-np.log10(pv_thr[0]), linestyle='--', color='#7d7d7d', linewidth=1)
  164. plt.axvline(x=lfc_thr[0], linestyle='--', color='#7d7d7d', linewidth=1)
  165. plt.axvline(x=-lfc_thr[1], linestyle='--', color='#7d7d7d', linewidth=1)
  166. GeneExpression.gene_plot(df, geneid, lfc, lfc_thr, pv_thr, genenames, gfont, pv, gstyle)
  167. if axxlabel:
  168. _x = axxlabel
  169. if axylabel:
  170. _y = axylabel
  171. general.axis_labels(_x, _y, axlabelfontsize, axlabelfontname)
  172. general.axis_ticks(xlm, ylm, axtickfontsize, axtickfontname, ar)
  173. general.get_figure(show, r, figtype, figname, theme)
  174. def involcano(df="dataframe", lfc="logFC", pv="p_values", lfc_thr=(1, 1), pv_thr=(0.05, 0.05), color=("green", "grey", "red"),
  175. valpha=1, geneid=None, genenames=None, gfont=8, dim=(5, 5), r=300, ar=90, dotsize=8, markerdot="o",
  176. sign_line=False, gstyle=1, show=False, figtype='png', axtickfontsize=9,
  177. axtickfontname="Arial", axlabelfontsize=9, axlabelfontname="Arial", axxlabel=None,
  178. axylabel=None, xlm=None, ylm=None, plotlegend=False, legendpos='best',
  179. figname='involcano', legendanchor=None, legendlabels=['significant up', 'not significant', 'significant down'],
  180. theme=None):
  181. _x = r'$ log_{2}(Fold Change)$'
  182. _y = r'$ -log_{10}(P-value)$'
  183. color = color
  184. assert general.check_for_nonnumeric(df[lfc]) == 0, 'dataframe contains non-numeric values in lfc column'
  185. assert general.check_for_nonnumeric(df[pv]) == 0, 'dataframe contains non-numeric values in pv column'
  186. # this is important to check if color or logpv exists and drop them as if you run multiple times same command
  187. # it may update old instance of df
  188. df = df.drop(['color_add_axy', 'logpv_add_axy'], axis=1, errors='ignore')
  189. assert len(set(color)) == 3, 'unique color must be size of 3'
  190. df.loc[(df[lfc] >= lfc_thr[0]) & (df[pv] < pv_thr[0]), 'color_add_axy'] = color[0] # upregulated
  191. df.loc[(df[lfc] <= -lfc_thr[1]) & (df[pv] < pv_thr[1]), 'color_add_axy'] = color[2] # downregulated
  192. df['color_add_axy'].fillna(color[1], inplace=True) # intermediate
  193. df['logpv_add_axy'] = -(np.log10(df[pv]))
  194. # plot
  195. assign_values = {col: i for i, col in enumerate(color)}
  196. color_result_num = [assign_values[i] for i in df['color_add_axy']]
  197. assert len(set(color_result_num)) == 3, 'either significant or non-significant genes are missing; try to change lfc_thr or ' \
  198. 'pv_thr to include both significant and non-significant genes'
  199. if theme == 'dark':
  200. general.dark_bg()
  201. plt.subplots(figsize=dim)
  202. if plotlegend:
  203. s = plt.scatter(df[lfc], df['logpv_add_axy'], c=color_result_num, cmap=ListedColormap(color), alpha=valpha,
  204. s=dotsize, marker=markerdot)
  205. assert len(legendlabels) == 3, 'legendlabels must be size of 3'
  206. plt.legend(handles=s.legend_elements()[0], labels=legendlabels, loc=legendpos,
  207. bbox_to_anchor=legendanchor)
  208. else:
  209. plt.scatter(df[lfc], df['logpv_add_axy'], c=color_result_num, cmap=ListedColormap(color), alpha=valpha,
  210. s=dotsize, marker=markerdot)
  211. GeneExpression.gene_plot(df, geneid, lfc, lfc_thr, pv_thr, genenames, gfont, pv, gstyle)
  212. plt.gca().invert_yaxis()
  213. if axxlabel:
  214. _x = axxlabel
  215. if axylabel:
  216. _y = axylabel
  217. general.axis_labels(_x, _y, axlabelfontsize, axlabelfontname)
  218. if xlm:
  219. print('Error: xlm not compatible with involcano')
  220. sys.exit(1)
  221. if ylm:
  222. print('Error: ylm not compatible with involcano')
  223. sys.exit(1)
  224. general.axis_ticks(xlm, ylm, axtickfontsize, axtickfontname, ar)
  225. general.get_figure(show, r, figtype, figname, theme)
  226. @staticmethod
  227. def ma(df="dataframe", lfc=None, ct_count=None, st_count=None, basemean=None, pv=None, lfc_thr=(1, 1), pv_thr=0.05,
  228. valpha=1, dotsize=8,markerdot="o", dim=(6, 5), r=300, show=False, color=("green", "grey", "red"), ar=0,
  229. figtype='png',axtickfontsize=9, axtickfontname="Arial", axlabelfontsize=9, axlabelfontname="Arial",
  230. axxlabel=None, axylabel=None, xlm=None, ylm=None, fclines=False, fclinescolor='#2660a4', legendpos='best',
  231. figname='ma', legendanchor=None, legendlabels=['significant up', 'not significant', 'significant down'],
  232. plotlegend=False, theme=None, geneid=None, genenames=None, gfont=8, gstyle=1, title=None):
  233. _x, _y = 'A', 'M'
  234. assert General.check_for_nonnumeric(df[lfc]) == 0, 'dataframe contains non-numeric values in lfc column'
  235. if ct_count and st_count:
  236. assert General.check_for_nonnumeric(df[ct_count]) == 0, \
  237. 'dataframe contains non-numeric values in ct_count column'
  238. assert General.check_for_nonnumeric(
  239. df[st_count]) == 0, 'dataframe contains non-numeric values in ct_count column'
  240. if basemean:
  241. assert General.check_for_nonnumeric(df[basemean]) == 0, \
  242. 'dataframe contains non-numeric values in basemean column'
  243. # this is important to check if color or A exists and drop them as if you run multiple times same command
  244. # it may update old instance of df
  245. df = df.drop(['color_add_axy', 'A_add_axy'], axis=1, errors='ignore')
  246. assert len(set(color)) == 3, 'unique color must be size of 3'
  247. df.loc[(df[lfc] >= lfc_thr[0]) & (df[pv] < pv_thr), 'color_add_axy'] = color[0] # upregulated
  248. df.loc[(df[lfc] <= -lfc_thr[1]) & (df[pv] < pv_thr), 'color_add_axy'] = color[2] # downregulated
  249. df['color_add_axy'].fillna(color[1], inplace=True) # intermediate
  250. if basemean:
  251. # basemean (mean of normalized counts from DESeq2 results)
  252. df['A_add_axy'] = df[basemean]
  253. else:
  254. df['A_add_axy'] = (np.log2(df[ct_count]) + np.log2(df[st_count])) / 2
  255. # plot
  256. assign_values = {col: i for i, col in enumerate(color)}
  257. color_result_num = [assign_values[i] for i in df['color_add_axy']]
  258. assert len(
  259. set(color_result_num)) == 3, 'either significant or non-significant genes are missing; try to change lfc_thr' \
  260. ' to include both significant and non-significant genes'
  261. if theme:
  262. General.style_bg(theme)
  263. plt.subplots(figsize=dim)
  264. if plotlegend:
  265. s = plt.scatter(df['A_add_axy'], df[lfc], c=color_result_num, cmap=ListedColormap(color),
  266. alpha=valpha, s=dotsize, marker=markerdot)
  267. assert len(legendlabels) == 3, 'legendlabels must be size of 3'
  268. plt.legend(handles=s.legend_elements()[0], labels=legendlabels, loc=legendpos,
  269. bbox_to_anchor=legendanchor)
  270. else:
  271. plt.scatter(df['A_add_axy'], df[lfc], c=color_result_num, cmap=ListedColormap(color),
  272. alpha=valpha, s=dotsize, marker=markerdot)
  273. # draw a central line at M=0
  274. plt.axhline(y=0, color='#7d7d7d', linestyle='--')
  275. # draw lfc threshold lines
  276. if fclines:
  277. plt.axhline(y=lfc_thr[0], color=fclinescolor, linestyle='--')
  278. plt.axhline(y=-lfc_thr[1], color=fclinescolor, linestyle='--')
  279. if axxlabel:
  280. _x = axxlabel
  281. if axylabel:
  282. _y = axylabel
  283. GeneExpression.geneplot_ma(df, geneid, lfc, lfc_thr, genenames, gfont, gstyle)
  284. General.axis_labels(_x, _y, axlabelfontsize, axlabelfontname)
  285. General.axis_ticks(xlm, ylm, axtickfontsize, axtickfontname, ar)
  286. General.get_figure(show, r, figtype, figname, theme, title)
  287. @staticmethod
  288. def hmap(df="dataframe", cmap="seismic", scale=True, dim=(4, 6), rowclus=True, colclus=True, zscore=None, xlabel=True,
  289. ylabel=True, tickfont=(10, 10), r=300, show=False, figtype='png', figname='heatmap', theme=None):
  290. # df = df.set_index(d.columns[0])
  291. # plot heatmap without cluster
  292. # more cmap: https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html
  293. if theme == 'dark':
  294. general.dark_bg()
  295. fig, hm = plt.subplots(figsize=dim)
  296. if rowclus and colclus:
  297. hm = sns.clustermap(df, cmap=cmap, cbar=scale, z_score=zscore, xticklabels=xlabel, yticklabels=ylabel,
  298. figsize=dim)
  299. hm.ax_heatmap.set_xticklabels(hm.ax_heatmap.get_xmajorticklabels(), fontsize=tickfont[0])
  300. hm.ax_heatmap.set_yticklabels(hm.ax_heatmap.get_ymajorticklabels(), fontsize=tickfont[1])
  301. general.get_figure(show, r, figtype, figname, theme)
  302. elif rowclus and colclus is False:
  303. hm = sns.clustermap(df, cmap=cmap, cbar=scale, z_score=zscore, xticklabels=xlabel, yticklabels=ylabel,
  304. figsize=dim, row_cluster=True, col_cluster=False)
  305. hm.ax_heatmap.set_xticklabels(hm.ax_heatmap.get_xmajorticklabels(), fontsize=tickfont[0])
  306. hm.ax_heatmap.set_yticklabels(hm.ax_heatmap.get_ymajorticklabels(), fontsize=tickfont[1])
  307. general.get_figure(show, r, figtype, figname, theme)
  308. elif colclus and rowclus is False:
  309. hm = sns.clustermap(df, cmap=cmap, cbar=scale, z_score=zscore, xticklabels=xlabel, yticklabels=ylabel,
  310. figsize=dim, row_cluster=False, col_cluster=True)
  311. hm.ax_heatmap.set_xticklabels(hm.ax_heatmap.get_xmajorticklabels(), fontsize=tickfont[0])
  312. hm.ax_heatmap.set_yticklabels(hm.ax_heatmap.get_ymajorticklabels(), fontsize=tickfont[1])
  313. general.get_figure(show, r, figtype, figname, theme)
  314. else:
  315. hm = sns.heatmap(df, cmap=cmap, cbar=scale, xticklabels=xlabel, yticklabels=ylabel)
  316. plt.xticks(fontsize=tickfont[0])
  317. plt.yticks(fontsize=tickfont[1])
  318. general.get_figure(show, r, figtype, figname, theme)
  319. class gene_exp:
  320. def __init__(self):
  321. pass
  322. def geneplot(d, geneid, lfc, lfc_thr, pv_thr, genenames, gfont, pv, gstyle):
  323. if genenames is not None and genenames == "deg":
  324. for i in d[geneid].unique():
  325. if (d.loc[d[geneid] == i, lfc].iloc[0] >= lfc_thr[0] and d.loc[d[geneid] == i, pv].iloc[0] < pv_thr[0]) or \
  326. (d.loc[d[geneid] == i, lfc].iloc[0] <= -lfc_thr[1] and d.loc[d[geneid] == i, pv].iloc[0] < pv_thr[1]):
  327. if gstyle==1:
  328. plt.text(d.loc[d[geneid] == i, lfc].iloc[0], d.loc[d[geneid] == i, 'logpv_add_axy'].iloc[0], i,
  329. fontsize=gfont)
  330. elif gstyle==2:
  331. plt.annotate(i, xy=(d.loc[d[geneid] == i, lfc].iloc[0], d.loc[d[geneid] == i, 'logpv_add_axy'].iloc[0]),
  332. xycoords='data', xytext=(5, -15), textcoords='offset points', size=6,
  333. bbox=dict(boxstyle="round", alpha=0.1),
  334. arrowprops=dict(arrowstyle="wedge,tail_width=0.5", alpha=0.1, relpos=(0, 0)))
  335. else:
  336. print("Error: invalid gstyle choice")
  337. sys.exit(1)
  338. elif genenames is not None and type(genenames) is tuple:
  339. for i in d[geneid].unique():
  340. if i in genenames:
  341. if gstyle==1:
  342. plt.text(d.loc[d[geneid] == i, lfc].iloc[0], d.loc[d[geneid] == i, 'logpv_add_axy'].iloc[0], i,
  343. fontsize=gfont)
  344. elif gstyle==2:
  345. plt.annotate(i, xy=(d.loc[d[geneid] == i, lfc].iloc[0], d.loc[d[geneid] == i, 'logpv_add_axy'].iloc[0]),
  346. xycoords='data', xytext=(5, -15), textcoords='offset points', size=6,
  347. bbox=dict(boxstyle="round", alpha=0.1),
  348. arrowprops=dict(arrowstyle="wedge,tail_width=0.5", alpha=0.1, relpos=(0, 0)))
  349. else:
  350. print("Error: invalid gstyle choice")
  351. sys.exit(1)
  352. elif genenames is not None and type(genenames) is dict:
  353. for i in d[geneid].unique():
  354. if i in genenames:
  355. if gstyle==1:
  356. plt.text(d.loc[d[geneid] == i, lfc].iloc[0], d.loc[d[geneid] == i, 'logpv_add_axy'].iloc[0],
  357. genenames[i], fontsize=gfont)
  358. elif gstyle == 2:
  359. plt.annotate(genenames[i], xy=(d.loc[d[geneid] == i, lfc].iloc[0], d.loc[d[geneid] == i, 'logpv_add_axy'].iloc[0]),
  360. xycoords='data', xytext=(5, -15), textcoords='offset points', size=6,
  361. bbox=dict(boxstyle="round", alpha=0.1),
  362. arrowprops=dict(arrowstyle="wedge,tail_width=0.5", alpha=0.1, relpos=(0, 0)))
  363. else:
  364. print("Error: invalid gstyle choice")
  365. sys.exit(1)
  366. def hmap(df="dataframe", cmap="seismic", scale=True, dim=(4, 6), rowclus=True, colclus=True, zscore=None, xlabel=True,
  367. ylabel=True, tickfont=(10, 10), r=300, show=False, figtype='png', figname='heatmap', theme=None):
  368. # df = df.set_index(d.columns[0])
  369. # plot heatmap without cluster
  370. # more cmap: https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html
  371. if theme == 'dark':
  372. general.dark_bg()
  373. fig, hm = plt.subplots(figsize=dim)
  374. if rowclus and colclus:
  375. hm = sns.clustermap(df, cmap=cmap, cbar=scale, z_score=zscore, xticklabels=xlabel, yticklabels=ylabel,
  376. figsize=dim)
  377. hm.ax_heatmap.set_xticklabels(hm.ax_heatmap.get_xmajorticklabels(), fontsize=tickfont[0])
  378. hm.ax_heatmap.set_yticklabels(hm.ax_heatmap.get_ymajorticklabels(), fontsize=tickfont[1])
  379. general.get_figure(show, r, figtype, figname, theme)
  380. elif rowclus and colclus is False:
  381. hm = sns.clustermap(df, cmap=cmap, cbar=scale, z_score=zscore, xticklabels=xlabel, yticklabels=ylabel,
  382. figsize=dim, row_cluster=True, col_cluster=False)
  383. hm.ax_heatmap.set_xticklabels(hm.ax_heatmap.get_xmajorticklabels(), fontsize=tickfont[0])
  384. hm.ax_heatmap.set_yticklabels(hm.ax_heatmap.get_ymajorticklabels(), fontsize=tickfont[1])
  385. general.get_figure(show, r, figtype, figname, theme)
  386. elif colclus and rowclus is False:
  387. hm = sns.clustermap(df, cmap=cmap, cbar=scale, z_score=zscore, xticklabels=xlabel, yticklabels=ylabel,
  388. figsize=dim, row_cluster=False, col_cluster=True)
  389. hm.ax_heatmap.set_xticklabels(hm.ax_heatmap.get_xmajorticklabels(), fontsize=tickfont[0])
  390. hm.ax_heatmap.set_yticklabels(hm.ax_heatmap.get_ymajorticklabels(), fontsize=tickfont[1])
  391. general.get_figure(show, r, figtype, figname, theme)
  392. else:
  393. hm = sns.heatmap(df, cmap=cmap, cbar=scale, xticklabels=xlabel, yticklabels=ylabel)
  394. plt.xticks(fontsize=tickfont[0])
  395. plt.yticks(fontsize=tickfont[1])
  396. general.get_figure(show, r, figtype, figname, theme)
  397. class General:
  398. rand_colors = ('#a7414a', '#282726', '#6a8a82', '#a37c27', '#563838', '#0584f2', '#f28a30', '#f05837',
  399. '#6465a5', '#00743f', '#be9063', '#de8cf0', '#888c46', '#c0334d', '#270101', '#8d2f23',
  400. '#ee6c81', '#65734b', '#14325c', '#704307', '#b5b3be', '#f67280', '#ffd082', '#ffd800',
  401. '#ad62aa', '#21bf73', '#a0855b', '#5edfff', '#08ffc8', '#ca3e47', '#c9753d', '#6c5ce7')
  402. def __init__(self):
  403. pass
  404. @staticmethod
  405. def get_figure(show, r, figtype, fig_name, theme, title):
  406. if title:
  407. plt.title(title)
  408. if show:
  409. plt.show()
  410. else:
  411. plt.savefig(fig_name+'.'+figtype, format=figtype, bbox_inches='tight', dpi=r)
  412. if theme:
  413. plt.style.use('default')
  414. plt.clf()
  415. plt.close()
  416. @staticmethod
  417. def axis_labels(x, y, axlabelfontsize=None, axlabelfontname=None):
  418. plt.xlabel(x, fontsize=axlabelfontsize, fontname=axlabelfontname)
  419. plt.ylabel(y, fontsize=axlabelfontsize, fontname=axlabelfontname)
  420. @staticmethod
  421. def axis_ticks(xlm=None, ylm=None, axtickfontsize=None, axtickfontname=None, ar=None):
  422. if xlm:
  423. plt.xlim(left=xlm[0], right=xlm[1])
  424. plt.xticks(np.arange(xlm[0], xlm[1], xlm[2]), fontsize=axtickfontsize, rotation=ar, fontname=axtickfontname)
  425. else:
  426. plt.xticks(fontsize=axtickfontsize, rotation=ar, fontname=axtickfontname)
  427. if ylm:
  428. plt.ylim(bottom=ylm[0], top=ylm[1])
  429. plt.yticks(np.arange(ylm[0], ylm[1], ylm[2]), fontsize=axtickfontsize, rotation=ar, fontname=axtickfontname)
  430. else:
  431. plt.yticks(fontsize=axtickfontsize, rotation=ar, fontname=axtickfontname)
  432. @staticmethod
  433. def depr_mes(func_name):
  434. print("This function is deprecated. Please use", func_name)
  435. print("Read docs at https://reneshbedre.github.io/blog/howtoinstall.html")
  436. @staticmethod
  437. def check_for_nonnumeric(pd_series=None):
  438. if pd.to_numeric(pd_series, errors='coerce').isna().sum() == 0:
  439. return 0
  440. else:
  441. return 1
  442. @staticmethod
  443. def pvalue_symbol(pv=None, symbol=None):
  444. if 0.05 >= pv > 0.01:
  445. return symbol
  446. elif 0.01 >= pv > 0.001:
  447. return 2 * symbol
  448. elif pv <= 0.001:
  449. return 3 * symbol
  450. else:
  451. return None
  452. @staticmethod
  453. def get_file_from_gd(url=None):
  454. get_path = 'https://drive.google.com/uc?export=download&id=' + url.split('/')[-2]
  455. return pd.read_csv(get_path, comment='#')
  @staticmethod
  def style_bg(theme=None):
      """Activate the matplotlib style sheet named *theme*."""
      chosen_style = theme
      plt.style.use(chosen_style)
  class general:
      """Plotting helpers shared by the chart classes (figure output, axes, p-value marks)."""

      def __init__(self):
          """Instances carry no state; every helper is a static method."""

      # palette of 32 distinct hex colors
      rand_colors = ('#a7414a', '#282726', '#6a8a82', '#a37c27', '#563838', '#0584f2', '#f28a30', '#f05837',
                     '#6465a5', '#00743f', '#be9063', '#de8cf0', '#888c46', '#c0334d', '#270101', '#8d2f23',
                     '#ee6c81', '#65734b', '#14325c', '#704307', '#b5b3be', '#f67280', '#ffd082', '#ffd800',
                     '#ad62aa', '#21bf73', '#a0855b', '#5edfff', '#08ffc8', '#ca3e47', '#c9753d', '#6c5ce7')

      @staticmethod
      def get_figure(show, r, figtype, fig_name, theme):
          """Show the current figure, or save it, then clear and close it.

          Saving writes '<fig_name>.<figtype>' with a tight bounding box at *r*
          dpi. For theme 'dark' the default matplotlib style is restored.
          """
          if not show:
              out_path = fig_name + '.' + figtype
              plt.savefig(out_path, format=figtype, bbox_inches='tight', dpi=r)
          else:
              plt.show()
          if theme == 'dark':
              plt.style.use('default')
          plt.clf()
          plt.close()

      @staticmethod
      def axis_labels(x, y, axlabelfontsize=None, axlabelfontname=None):
          """Label both axes of the current axes with a shared font size/family."""
          label_font = dict(fontsize=axlabelfontsize, fontname=axlabelfontname)
          plt.xlabel(x, **label_font)
          plt.ylabel(y, **label_font)

      @staticmethod
      def axis_ticks(xlm=None, ylm=None, axtickfontsize=None, axtickfontname=None, ar=None):
          """Set limits and tick labels on the current axes.

          xlm / ylm, when truthy, are (start, end, step): the range becomes
          (start, end) and ticks go at np.arange(start, end, step). Otherwise
          only tick-label font and rotation *ar* are applied.
          """
          tick_style = dict(fontsize=axtickfontsize, rotation=ar, fontname=axtickfontname)
          if not xlm:
              plt.xticks(**tick_style)
          else:
              plt.xlim(left=xlm[0], right=xlm[1])
              plt.xticks(np.arange(xlm[0], xlm[1], xlm[2]), **tick_style)
          if not ylm:
              plt.yticks(**tick_style)
          else:
              plt.ylim(bottom=ylm[0], top=ylm[1])
              plt.yticks(np.arange(ylm[0], ylm[1], ylm[2]), **tick_style)

      @staticmethod
      def depr_mes(func_name):
          """Print a deprecation notice naming the replacement *func_name*."""
          notice_parts = ("This function is deprecated. Please use", func_name)
          print(*notice_parts)
          print("Read docs at https://reneshbedre.github.io/blog/howtoinstall.html")

      @staticmethod
      def check_for_nonnumeric(pd_series=None):
          """Return 0 when every value of *pd_series* is numeric, else 1 (missing values count as non-numeric)."""
          coerced = pd.to_numeric(pd_series, errors='coerce')
          return 0 if coerced.notna().all() else 1

      @staticmethod
      def pvalue_symbol(pv=None, symbol=None):
          """Return symbol x3 for pv <= 0.001, x2 for pv <= 0.01, x1 for pv <= 0.05, else None."""
          if pv <= 0.001:
              return 3 * symbol
          if pv <= 0.01:
              return 2 * symbol
          if pv <= 0.05:
              return symbol
          return None

      @staticmethod
      def get_file_from_gd(url=None):
          """Read a Google-Drive-shared CSV; the file id is the second-to-last '/' part of *url*."""
          file_id = url.split('/')[-2]
          download_url = 'https://drive.google.com/uc?export=download&id=' + file_id
          return pd.read_csv(download_url, comment='#')

      @staticmethod
      def dark_bg():
          """Switch matplotlib to its 'dark_background' style."""
          plt.style.use('dark_background')
  class marker:
      """Manhattan plots for genome-wide p-values (or raw values such as Fst)."""

      def __init__(self):
          """Instances carry no state; call the helpers on the class."""

      @staticmethod
      def _draw_label(ax, x_pos, y_pos, text, gstyle, gfont):
          """Write *text* at data point (x_pos, y_pos) of *ax* in label style *gstyle*.

          gstyle 1: plain text of font size *gfont*; gstyle 2: small boxed
          annotation offset from the point with a wedge arrow; other values draw nothing.
          """
          if gstyle == 1:
              ax.text(x_pos, y_pos, text, fontsize=gfont)
          elif gstyle == 2:
              ax.annotate(text, xy=(x_pos, y_pos), xycoords='data', xytext=(5, -15), textcoords='offset points',
                          size=6, bbox=dict(boxstyle="round", alpha=0.2),
                          arrowprops=dict(arrowstyle="wedge,tail_width=0.5", alpha=0.2, relpos=(0, 0)))

      @staticmethod
      def geneplot_mhat(df, markeridcol, chr, pv, gwasp, markernames, gfont, gstyle, ax):
          """Label selected markers on an already drawn Manhattan plot.

          *df* must contain the 'ind' (x position) and 'tpval' (plotted y)
          columns built by mhat; each marker id is labelled at its first row.

          markernames selects the markers:
            * True        -- ids whose raw p-value (column *pv*) is <= gwasp, labelled with the id;
            * tuple/list  -- ids contained in the sequence, labelled with the id;
            * dict        -- ids that are keys, labelled with the mapped value.
          Any other value labels nothing. Rows with a missing id are ignored.

          Raises:
              Exception: markeridcol is None.
          """
          if markeridcol is None:
              raise Exception("provide 'markeridcol' parameter")
          if not (markernames is True or isinstance(markernames, (tuple, list, dict))):
              return
          if ax is None:
              ax = plt.gca()
          # one row per marker id (first occurrence, plot order) instead of a boolean scan of
          # the whole frame per marker, which was O(n_markers * n_rows)
          first_rows = df.drop_duplicates(subset=markeridcol, keep='first')
          first_rows = first_rows[first_rows[markeridcol].notna()]
          for marker_id, x_pos, y_pos, raw_p in zip(first_rows[markeridcol], first_rows['ind'],
                                                    first_rows['tpval'], first_rows[pv]):
              if markernames is True:
                  if not raw_p <= gwasp:
                      continue
                  text = str(marker_id)
              elif isinstance(markernames, dict):
                  if marker_id not in markernames:
                      continue
                  text = markernames[marker_id]
              else:
                  if marker_id not in markernames:
                      continue
                  text = str(marker_id)
              marker._draw_label(ax, x_pos, y_pos, text, gstyle, gfont)

      @staticmethod
      def mhat(df="dataframe", chr=None, pv=None, log_scale=True, color=None, dim=(6,4), r=300, ar=90, gwas_sign_line=False,
               gwasp=5E-08, dotsize=8, markeridcol=None, markernames=None, gfont=8, valpha=1, show=False, figtype='png',
               axxlabel=None, axylabel=None, axlabelfontsize=9, axlabelfontname="Arial", axtickfontsize=9,
               axtickfontname="Arial", ylm=None, gstyle=1, figname='manhattan', theme=None):
          """Draw a Manhattan plot and show it or save it as '<figname>.<figtype>'.

          Args (selection):
              df: one row per marker. It is not modified (work is done on a copy).
              chr: chromosome column. Chromosomes whose value parses as a number are
                  ordered numerically, the others follow; within a chromosome the
                  input row order (e.g. base-pair position) is kept.
              pv: value column; plotted as -log10(value) when log_scale is True,
                  otherwise as is (e.g. Fst).
              color: None (random colors from a fixed 40-color palette), two colors
                  (alternated chromosome by chromosome) or one color per chromosome.
              gwas_sign_line: draw a dashed line at -log10(gwasp).
              markeridcol, markernames, gstyle, gfont: marker labelling, see geneplot_mhat.
              ylm: optional (start, end, step) for the y ticks.
              theme: 'dark' draws on a dark background.
          An invalid *color* prints an error and exits the interpreter (sys.exit(1)).
          """
          _x, _y = 'Chromosomes', r'$ -log_{10}(P)$'
          rand_colors = ('#a7414a', '#282726', '#6a8a82', '#a37c27', '#563838', '#0584f2', '#f28a30', '#f05837',
                         '#6465a5', '#00743f', '#be9063', '#de8cf0', '#888c46', '#c0334d', '#270101', '#8d2f23',
                         '#ee6c81', '#65734b', '#14325c', '#704307', '#b5b3be', '#f67280', '#ffd082', '#ffd800',
                         '#ad62aa', '#21bf73', '#a0855b', '#5edfff', '#08ffc8', '#ca3e47', '#c9753d', '#6c5ce7',
                         '#a997df', '#513b56', '#590925', '#007fff', '#bf1363', '#f39237', '#0a3200', '#8c271e')
          # work on a copy so the caller's frame does not gain helper columns
          df = df.copy()
          if log_scale:
              # minus log10 of P-value
              df['tpval'] = -np.log10(df[pv])
          else:
              # for Fst values
              df['tpval'] = df[pv]
          # Order by numeric chromosome (non-numeric -> NaN -> last). The stable sort keeps the
          # within-chromosome input order, and positional indexing is safe with duplicate index labels.
          chr_num = pd.to_numeric(df[chr], errors='coerce').to_numpy(dtype=float, na_value=np.nan)
          df = df.iloc[np.argsort(chr_num, kind='stable')].copy()
          # x position of every marker
          df['ind'] = range(len(df))
          n_chr = df[chr].nunique()
          if color is not None and len(color) == 2:
              # alternate the two colors along the chromosomes (also valid for a single chromosome)
              color_list = [color[k % 2] for k in range(n_chr)]
          elif color is not None and len(color) == n_chr:
              color_list = color
          elif color is None:
              # select colors randomly from the list based on number of chr
              color_list = sample(rand_colors, n_chr)
          else:
              print("Error: in color argument")
              sys.exit(1)
          xlabels = []
          xticks = []
          if theme == 'dark':
              general.dark_bg()
          fig, ax = plt.subplots(figsize=dim)
          # sort=False keeps chromosomes in x-axis order, so colors alternate along the plot and
          # mixed int/str chromosome labels need not be mutually comparable
          for i, (label, df1) in enumerate(df.groupby(chr, sort=False)):
              df1.plot(kind='scatter', x='ind', y='tpval', color=color_list[i], s=dotsize, alpha=valpha, ax=ax)
              df1_max_ind = df1['ind'].iloc[-1]
              df1_min_ind = df1['ind'].iloc[0]
              xlabels.append(label)
              # tick at the middle of this chromosome's x range
              xticks.append((df1_max_ind - (df1_max_ind - df1_min_ind) / 2))
          # add GWAS significant line
          if gwas_sign_line is True:
              ax.axhline(y=-np.log10(gwasp), linestyle='--', color='#7d7d7d', linewidth=1)
          if markernames is not None:
              marker.geneplot_mhat(df, markeridcol, chr, pv, gwasp, markernames, gfont, gstyle, ax=ax)
          ax.margins(x=0)
          ax.margins(y=0)
          ax.set_xticks(xticks)
          # Series.max skips NaN; builtin max() returned NaN when the first value was NaN
          y_top = df['tpval'].max() + 1
          if log_scale:
              ax.set_ylim([0, y_top])
          if ylm:
              ylm = np.arange(ylm[0], ylm[1], ylm[2])
          else:
              ylm = np.arange(0, y_top, 1)
          ax.set_yticks(ylm)
          ax.set_xticklabels(xlabels, rotation=ar)
          if axxlabel:
              _x = axxlabel
          if axylabel:
              _y = axylabel
          ax.set_xlabel(_x, fontsize=axlabelfontsize, fontname=axlabelfontname)
          ax.set_ylabel(_y, fontsize=axlabelfontsize, fontname=axlabelfontname)
          general.get_figure(show, r, figtype, figname, theme)
  class Statis:
      """Statistical count plots (work in progress)."""

      def __init__(self):
          """Instances carry no state."""

      @staticmethod
      def count_plot(df='dataframe', factor=None, dim=(6, 4)):
          """Count how often each level of a categorical column occurs.

          Work in progress: no figure is drawn yet; the counts the plot would
          display are returned instead (previously they were computed and discarded).

          Args:
              df: input DataFrame.
              factor: column to count. Defaults to 'disease', the column that was
                  previously hard-coded (the argument used to be ignored).
              dim: intended figure size (currently unused).

          Returns:
              pandas Series of counts per level, in ``value_counts`` order (most frequent first).
          """
          column = 'disease' if factor is None else factor
          return df[column].value_counts()
  648. class stat:
  649. def __init__(self):
  650. pass
  def bardot(df="dataframe", dim=(6, 4), bw=0.4, colorbar="#f2aa4cff", colordot=["#101820ff"], hbsize=4, r=300, ar=0,
             dotsize=6, valphabar=1, valphadot=1, markerdot="o", errorbar=True, show=False, ylm=None, axtickfontsize=9,
             axtickfontname="Arial", axlabelfontsize=9, axlabelfontname="Arial", yerrlw=None, yerrcw=None, axxlabel=None,
             axylabel=None, figtype='png', theme=None):
      """Bar plot of column means with each column's observations overlaid as dots.

      Every column of *df* is one bar whose height is the column mean; with
      errorbar=True the standard error (df.sem()) is drawn as the error bar.
      The non-missing values of a column are spread evenly across its bar
      width. The figure is shown or saved as 'bardot.<figtype>'.

      Args (selection):
          colordot: one color for all dots, or one color per column (not modified).
          ylm: optional (start, end, step) for the y-axis limits and ticks.
          theme: 'dark' draws on a dark background. It used to be referenced without
              being a parameter, so every call failed with NameError.
      """
      # set axis labels to None
      _x = None
      _y = None
      columns = df.columns.to_numpy()
      xbar = np.arange(len(columns))
      color_list_bar = colorbar
      color_list_dot = colordot
      if len(color_list_dot) == 1:
          color_list_dot = colordot * len(columns)
      if theme == 'dark':
          general.dark_bg()
      plt.subplots(figsize=dim)
      # computed once instead of calling describe() for every bar/dot column
      means = df.describe().loc['mean']
      if errorbar:
          plt.bar(x=xbar, height=means, yerr=df.sem(), width=bw, color=color_list_bar, capsize=hbsize,
                  zorder=0, alpha=valphabar, error_kw={'elinewidth': yerrlw, 'capthick': yerrcw})
      else:
          plt.bar(x=xbar, height=means, width=bw, color=color_list_bar, capsize=hbsize,
                  zorder=0, alpha=valphabar)
      plt.xticks(xbar, columns, fontsize=axtickfontsize, rotation=ar, fontname=axtickfontname)
      if axxlabel:
          _x = axxlabel
      if axylabel:
          _y = axylabel
      general.axis_labels(_x, _y, axlabelfontsize, axlabelfontname)
      # ylm must be tuple of start, end, interval
      if ylm:
          plt.ylim(bottom=ylm[0], top=ylm[1])
          plt.yticks(np.arange(ylm[0], ylm[1], ylm[2]), fontsize=axtickfontsize, fontname=axtickfontname)
      plt.yticks(fontsize=axtickfontsize, rotation=ar, fontname=axtickfontname)
      # add dots; markers: https://matplotlib.org/3.1.1/api/markers_api.html
      for col_idx, col_name in enumerate(columns):
          values = df[col_name].dropna()
          # one dot per non-missing value (positional Series[int] access is deprecated in pandas)
          plt.scatter(x=np.linspace(xbar[col_idx] - bw / 2, xbar[col_idx] + bw / 2, len(values)),
                      y=values, s=dotsize, color=color_list_dot[col_idx], zorder=1, alpha=valphadot,
                      marker=markerdot)
      general.get_figure(show, r, figtype, 'bardot', theme)
  def regplot(df="dataframe", x=None, y=None, yhat=None, dim=(6, 4), colordot='#4a4e4d', colorline='#fe8a71', r=300,
              ar=0, dotsize=6, valphaline=1, valphadot=1, linewidth=1, markerdot="o", show=False, axtickfontsize=9,
              axtickfontname="Arial", axlabelfontsize=9, axlabelfontname="Arial", ylm=None, xlm=None, axxlabel=None,
              axylabel=None, figtype='png', theme=None):
      """Scatter observed data and overlay the fitted regression line.

      Points are df[x] vs df[y]; the line is df[x] vs df[yhat]. Axis labels
      default to the column names unless axxlabel / axylabel are given. The
      figure is shown or saved as 'reg_plot.<figtype>'.
      """
      if theme == 'dark':
          general.dark_bg()
      plt.subplots(figsize=dim)
      x_values = df[x].to_numpy()
      plt.scatter(x_values, df[y].to_numpy(), color=colordot, s=dotsize, alpha=valphadot, marker=markerdot,
                  label='Observed data')
      plt.plot(x_values, df[yhat].to_numpy(), color=colorline, linewidth=linewidth, alpha=valphaline,
               label='Regression line')
      x_label = axxlabel if axxlabel else x
      y_label = axylabel if axylabel else y
      general.axis_labels(x_label, y_label, axlabelfontsize, axlabelfontname)
      general.axis_ticks(xlm, ylm, axtickfontsize, axtickfontname, ar)
      plt.legend(fontsize=9)
      general.get_figure(show, r, figtype, 'reg_plot', theme)
  def reg_resid_plot(df="dataframe", yhat=None, resid=None, stdresid=None, dim=(6, 4), colordot='#4a4e4d',
                     colorline='#2ab7ca', r=300, ar=0, dotsize=6, valphaline=1, valphadot=1, linewidth=1,
                     markerdot="o", show=False, figtype='png', theme=None):
      """Plot residuals and/or standardized residuals against the fitted values.

      df[resid] vs df[yhat] is shown/saved as 'resid_plot' and df[stdresid]
      vs df[yhat] as 'std_resid_plot', each in its own figure of size *dim*
      with a dashed zero line. A missing column name prints an error message
      and that plot is skipped.
      """
      def _draw(col, ylabel, figname):
          # get_figure closes the figure and undoes the dark theme, so each plot needs its own
          # figure and theme; the second plot used to land on a default-size, default-style figure
          if theme == 'dark':
              general.dark_bg()
          plt.subplots(figsize=dim)
          plt.scatter(df[yhat], df[col], color=colordot, s=dotsize, alpha=valphadot, marker=markerdot)
          plt.axhline(y=0, color=colorline, linestyle='--', linewidth=linewidth, alpha=valphaline)
          plt.xlabel("Fitted")
          plt.ylabel(ylabel)
          general.get_figure(show, r, figtype, figname, theme)

      if resid is not None:
          _draw(resid, "Residuals", 'resid_plot')
      else:
          print("Error: Provide residual data")
      if stdresid is not None:
          _draw(stdresid, "Standardized Residuals", 'std_resid_plot')
      else:
          print("Error: Provide standardized residual data")
  def corr_mat(df="dataframe", corm="pearson", cmap="seismic", r=300, show=False, dim=(6, 5), axtickfontname="Arial",
               axtickfontsize=7, ar=90, figtype='png', theme=None):
      """Draw the correlation matrix of df's columns as a heat map.

      Correlations come from df.corr(method=corm) and use a fixed [-1, 1]
      color scale with a colorbar. The figure (size *dim*) is shown or saved
      as 'corr_mat.<figtype>'.
      """
      if theme == 'dark':
          general.dark_bg()
      d_corr = df.corr(method=corm)
      # draw into this figure: plt.matshow opened a second figure, ignoring *dim* and
      # leaving this one empty and never closed
      fig, ax = plt.subplots(figsize=dim)
      image = ax.matshow(d_corr, vmin=-1, vmax=1, cmap=cmap)
      fig.colorbar(image, ax=ax)
      # label from the matrix itself so labels stay aligned if df.corr dropped columns
      cols = list(d_corr.columns)
      ticks = list(range(len(cols)))
      ax.set_xticks(ticks)
      ax.set_xticklabels(cols, fontsize=axtickfontsize, fontname=axtickfontname, rotation=ar)
      ax.set_yticks(ticks)
      ax.set_yticklabels(cols, fontsize=axtickfontsize, fontname=axtickfontname)
      general.get_figure(show, r, figtype, 'corr_mat', theme)
  745. # for data with pre-calculated mean and SE
  def multi_bar(df="dataframe", dim=(5, 4), colbar=None, colerrorbar=None, bw=0.4, colorbar=None, xbarcol=None, r=300,
                show=False, axtickfontname="Arial", axtickfontsize=9, ax_x_ticklabel=None, ar=90, figtype='png',
                figname='multi_bar', valphabar=1, legendpos='best', errorbar=False, yerrlw=None, yerrcw=None,
                plotlegend=False, hbsize=4, ylm=None, add_sign_line=False, pv=None,
                sign_line_opts={'symbol': '*', 'fontsize': 8, 'linewidth': 0.8, 'arrowstyle': '-', 'dist_y_pos': 2.5,
                                'dist_y_neg': 4.2},
                add_sign_symbol=False, sign_symbol_opts={'symbol': '*', 'fontsize': 8},
                dotplot=False, sub_cat=None,
                sub_cat_opts={'y_neg_dist': 3.5, 'fontsize': 8}, sub_cat_label_dist=None, theme=None):
      """Grouped bar plot from pre-computed means and errors (e.g. SE).

      Each row of *df* is one group on the x axis. Within a group, bar i shows
      df[colbar[i]] in color colorbar[i], shifted right by i * bw; with
      errorbar=True, df[colerrorbar[i]] is its error bar. The figure is shown
      or saved as '<figname>.<figtype>'. The option dicts are only read.

      Args (selection):
          colbar: tuple/list of at least two value columns (one bar per column).
          colerrorbar: matching error columns; also needed by add_sign_line and
              add_sign_symbol, which place marks above value + error.
          colorbar: one color per entry of colbar.
          xbarcol: column with the group tick labels (unless ax_x_ticklabel is given).
          ylm: optional (start, end, step) for the y-axis limits and ticks.
          add_sign_line: with exactly two bars, bracket each pair and label it with
              general.pvalue_symbol(pv[i]).
          add_sign_symbol: with two or three bars, write general.pvalue_symbol(pv[i][j])
              above bar j of group i.
          sub_cat: dict mapping (first_x, last_x) tuples to a label drawn below the
              axis under those groups; sub_cat_label_dist (list) shifts each label down.
          dotplot: unsupported here (no raw observations); raises NotImplementedError.

      Raises:
          ValueError: colbar or colorbar is None.
          AssertionError: fewer than two bars, or len(colorbar) != len(colbar).
          NotImplementedError: dotplot=True.
          KeyError: a sub_cat key is not a 2-tuple.
      """
      # validate before any figure is opened so a failure leaves no stray figure behind
      if colbar is None or colorbar is None:
          raise ValueError('Invalid value for colbar or colorbar options')
      assert len(colbar) >= 2, "number of bar should be atleast 2"
      assert len(colbar) == len(colorbar), "number of color should be equivalent to number of column bars"
      if dotplot:
          # the former dot code used names that do not exist here (df2, reps, dotsize, ...)
          # and always failed with NameError
          raise NotImplementedError("dotplot is not supported by multi_bar; use multi_bar_raw for raw replicates")
      xbar = np.arange(df.shape[0])
      xbar_temp = xbar
      if theme == 'dark':
          general.dark_bg()
      fig, ax = plt.subplots(figsize=dim)
      if isinstance(colbar, (tuple, list)):
          for i, (bar_col, bar_color) in enumerate(zip(colbar, colorbar)):
              if errorbar:
                  ax.bar(x=xbar_temp, height=df[bar_col], yerr=df[colerrorbar[i]], width=bw, color=bar_color,
                         alpha=valphabar, capsize=hbsize, label=bar_col,
                         error_kw={'elinewidth': yerrlw, 'capthick': yerrcw})
              else:
                  ax.bar(x=xbar_temp, height=df[bar_col], width=bw, color=bar_color, alpha=valphabar,
                         label=bar_col)
              xbar_temp = xbar_temp + bw
      # centre each tick under its group (bars at xbar .. xbar + (n - 1) * bw); the former
      # offset bw * (n - 1) / n was only centred for n == 2
      ax.set_xticks(xbar + (len(colbar) - 1) * bw / 2)
      if ax_x_ticklabel:
          x_ticklabel = ax_x_ticklabel
      else:
          x_ticklabel = df[xbarcol]
      ax.set_xticklabels(x_ticklabel, fontsize=axtickfontsize, rotation=ar, fontname=axtickfontname)
      # ylm must be tuple of start, end, interval
      if ylm:
          plt.ylim(bottom=ylm[0], top=ylm[1])
          plt.yticks(np.arange(ylm[0], ylm[1], ylm[2]), fontsize=axtickfontsize, fontname=axtickfontname)
      if plotlegend:
          plt.legend(loc=legendpos)
      if add_sign_line:
          if len(colbar) == 2:
              for i in xbar:
                  x_pos = xbar[i]
                  x_pos_2 = xbar[i] + bw
                  # bracket ends start at the top of each error bar
                  y_pos = df[colbar[0]].to_numpy()[i] + df[colerrorbar[0]].to_numpy()[i]
                  y_pos_2 = df[colbar[1]].to_numpy()[i] + df[colerrorbar[1]].to_numpy()[i]
                  # only if y axis is positive
                  if y_pos > 0:
                      y_pos += 0.5
                      y_pos_2 += 0.5
                      pv_symb = general.pvalue_symbol(pv[i], sign_line_opts['symbol'])
                      if pv_symb:
                          ax.annotate('', xy=(x_pos, y_pos), xytext=(x_pos_2, y_pos),
                                      arrowprops={'connectionstyle': 'bar, armA=50, armB=50, angle=180, fraction=0 ',
                                                  'arrowstyle': sign_line_opts['arrowstyle'],
                                                  'linewidth': sign_line_opts['linewidth']})
                          ax.annotate(pv_symb, xy=(np.mean([x_pos, x_pos_2]), max(y_pos, y_pos_2) +
                                                   sign_line_opts['dist_y_pos']),
                                      fontsize=sign_line_opts['fontsize'], ha="center")
                  else:
                      # negative bars: bracket drawn below
                      y_pos -= 0.5
                      y_pos_2 -= 0.5
                      pv_symb = general.pvalue_symbol(pv[i], sign_line_opts['symbol'])
                      if pv_symb:
                          ax.annotate('', xy=(x_pos, y_pos), xytext=(x_pos_2, y_pos),
                                      arrowprops={'connectionstyle': 'bar, armA=50, armB=50, angle=180, fraction=-1 ',
                                                  'arrowstyle': sign_line_opts['arrowstyle'],
                                                  'linewidth': sign_line_opts['linewidth']})
                          ax.annotate(pv_symb, xy=(np.mean([x_pos, x_pos_2]), min(y_pos_2, y_pos) -
                                                   sign_line_opts['dist_y_neg']),
                                      fontsize=sign_line_opts['fontsize'], ha="center")
      if add_sign_symbol:
          if len(colbar) == 2:
              for i in xbar:
                  x_pos = xbar[i]
                  x_pos_2 = xbar[i] + bw
                  # max value size factor is essential for rel pos of symbol
                  y_pos = df[colbar[0]].to_numpy()[i] + df[colerrorbar[0]].to_numpy()[i] + \
                      (max(df[colbar[0]].to_numpy()) / 20)
                  y_pos_2 = df[colbar[1]].to_numpy()[i] + df[colerrorbar[1]].to_numpy()[i] + \
                      (max(df[colbar[1]].to_numpy()) / 20)
                  # only if y axis is positive
                  if y_pos > 0:
                      pv_symb_1 = general.pvalue_symbol(pv[i][0], sign_symbol_opts['symbol'])
                      pv_symb_2 = general.pvalue_symbol(pv[i][1], sign_symbol_opts['symbol'])
                      if pv_symb_1:
                          plt.annotate(pv_symb_1, xy=(x_pos, y_pos), fontsize=sign_symbol_opts['fontsize'],
                                       ha="center")
                      if pv_symb_2:
                          plt.annotate(pv_symb_2, xy=(x_pos_2, y_pos_2), fontsize=sign_symbol_opts['fontsize'],
                                       ha="center")
          elif len(colbar) == 3:
              for i in xbar:
                  x_pos = xbar[i]
                  x_pos_2 = xbar[i] + bw
                  x_pos_3 = xbar[i] + (2 * bw)
                  # max value size factor is essential for rel pos of symbol
                  y_pos = df[colbar[0]].to_numpy()[i] + df[colerrorbar[0]].to_numpy()[i] + \
                      (max(df[colbar[0]].to_numpy()) / 20)
                  y_pos_2 = df[colbar[1]].to_numpy()[i] + df[colerrorbar[1]].to_numpy()[i] + \
                      (max(df[colbar[1]].to_numpy()) / 20)
                  y_pos_3 = df[colbar[2]].to_numpy()[i] + df[colerrorbar[2]].to_numpy()[i] + \
                      (max(df[colbar[2]].to_numpy()) / 20)
                  # only if y axis is positive
                  if y_pos > 0:
                      pv_symb_1 = general.pvalue_symbol(pv[i][0], sign_symbol_opts['symbol'])
                      pv_symb_2 = general.pvalue_symbol(pv[i][1], sign_symbol_opts['symbol'])
                      pv_symb_3 = general.pvalue_symbol(pv[i][2], sign_symbol_opts['symbol'])
                      if pv_symb_1:
                          plt.annotate(pv_symb_1, xy=(x_pos, y_pos), fontsize=sign_symbol_opts['fontsize'],
                                       ha="center")
                      if pv_symb_2:
                          plt.annotate(pv_symb_2, xy=(x_pos_2, y_pos_2), fontsize=sign_symbol_opts['fontsize'],
                                       ha="center")
                      if pv_symb_3:
                          plt.annotate(pv_symb_3, xy=(x_pos_3, y_pos_3), fontsize=sign_symbol_opts['fontsize'],
                                       ha="center")
      # update this later for min_value
      min_value = 0
      sub_cat_i = 0
      if sub_cat:
          if isinstance(sub_cat, dict):
              # vertical unit for sub-category lines/labels (tallest bar / 20, as multi_bar_raw
              # does with its default div_fact); it was previously undefined here (NameError)
              size_factor_to_start_line = df[list(colbar)].max().max() / 20
              for k in sub_cat:
                  if isinstance(k, tuple) and len(k) == 2:
                      cat_x_pos, cat_y_pos, cat_x_pos_2 = k[0], min_value - \
                          (sub_cat_opts['y_neg_dist'] * size_factor_to_start_line), k[1]
                      plt.annotate('', xy=(cat_x_pos - (bw / 2), cat_y_pos),
                                   xytext=(cat_x_pos_2 + (bw / 2), cat_y_pos),
                                   arrowprops={'arrowstyle': '-', 'linewidth': 0.5}, annotation_clip=False)
                      if sub_cat_label_dist and isinstance(sub_cat_label_dist, list):
                          plt.annotate(sub_cat[k], xy=(np.mean([cat_x_pos, cat_x_pos_2]),
                                                       cat_y_pos - size_factor_to_start_line -
                                                       sub_cat_label_dist[sub_cat_i]),
                                       ha="center", fontsize=sub_cat_opts['fontsize'], annotation_clip=False)
                          sub_cat_i += 1
                      else:
                          plt.annotate(sub_cat[k], xy=(np.mean([cat_x_pos, cat_x_pos_2]),
                                                       cat_y_pos - size_factor_to_start_line),
                                       ha="center", fontsize=sub_cat_opts['fontsize'], annotation_clip=False)
                  else:
                      raise KeyError("Sub category keys must be tuple of size 2")
      general.get_figure(show, r, figtype, figname, theme)
  895. # with replicates values stacked replicates
  896. # need to work on this later
  897. def multi_bar_raw(df="dataframe", dim=(5, 4), samp_col_name=None, bw=0.4, colorbar=None, r=300,
  898. show=False, axtickfontname="Arial", axtickfontsize=(9, 9), ax_x_ticklabel=None, ar=(0, 90), figtype='png',
  899. figname='multi_bar', valphabar=1, legendpos='best', errorbar=False, yerrlw=None, yerrcw=None,
  900. plotlegend=False, hbsize=4, ylm=None, add_sign_line=False, pv=None,
  901. sign_line_opts={'symbol': '*', 'fontsize': 9, 'linewidth': 0.8, 'arrowstyle': '-', 'dist_y_pos': 2.5,
  902. 'dist_y_neg': 4.2}, add_sign_symbol=False,
  903. sign_symbol_opts={'symbol': '*', 'fontsize': 9, 'fontname':'Arial', 'rotation':0},
  904. dotplot=False, dotplot_opts={'dotsize': 5, 'color':'#7d0013', 'valpha': 1, 'marker': 'o'},
  905. sign_line_pairs=None, group_let_df=None, legendanchor=None, legendcols=1, legendfontsize=8,
  906. axylabel=None, axxlabel=None, symb_dist=None, axlabelfontsize=(9, 9), axlabelar=(0, 90), sub_cat=None,
  907. sub_cat_opts={'y_neg_dist': 3.5, 'fontsize': 9, 'fontname':'Arial'}, sub_cat_label_dist=None,
  908. legendlabelframe=False, div_fact=20, legend_columnspacing=None, add_text=None, theme=None):
  909. if samp_col_name is None or colorbar is None:
  910. raise ValueError('Invalid value for samp_col_name or colorbar options')
  911. if theme == 'dark':
  912. general.dark_bg()
  913. fig, ax = plt.subplots(figsize=dim)
  914. sample_list = df[samp_col_name].unique()
  915. # assert len(sample_list) >= 2, "number of bar should be atleast 2"
  916. df_mean = df.groupby(samp_col_name).mean().reset_index().set_index(samp_col_name).T
  917. df_sem = df.groupby(samp_col_name).sem().reset_index().set_index(samp_col_name).T
  918. colbar = sample_list
  919. colerrorbar = sample_list
  920. xbar = np.arange(df_mean.shape[0])
  921. xbar_temp = xbar
  922. xbarcol = df_mean.index
  923. assert len(colbar) == len(colorbar), "number of color should be equivalent to number of column bars"
  924. df_melt = pd.melt(df.reset_index(), id_vars=[samp_col_name], value_vars=df_mean.index)
  925. variable_list = df_melt['variable'].unique()
  926. min_value = (0, min(df_mean.min()))[min(df_mean.min()) < 0]
  927. if colbar is not None:
  928. for i in range(len(colbar)):
  929. if errorbar:
  930. ax.bar(x=xbar_temp, height=df_mean[colbar[i]], yerr=df_sem[colerrorbar[i]], width=bw,
  931. color=colorbar[i], alpha=valphabar, capsize=hbsize, label=colbar[i],
  932. error_kw={'elinewidth': yerrlw, 'capthick': yerrcw})
  933. xbar_temp = xbar_temp + bw
  934. else:
  935. ax.bar(x=xbar_temp, height=df_mean[colbar[i]], width=bw, color=colorbar[i], alpha=valphabar,
  936. label=colbar[i])
  937. xbar_temp = xbar_temp + bw
  938. bw_fact = bw / 2
  939. ax.set_xticks(xbar+((len(df_mean.columns)-1) * bw_fact) )
  940. # ax.set_xticks(xbar + ((bw * (len(colbar) - 1)) / (1 + (len(colbar) - 1))))
  941. if ax_x_ticklabel:
  942. x_ticklabel = ax_x_ticklabel
  943. else:
  944. x_ticklabel = df[xbarcol]
  945. ax.set_xticklabels(x_ticklabel, fontsize=axtickfontsize[0], rotation=ar[0], fontname=axtickfontname)
  946. if axylabel:
  947. ax.set_ylabel(axylabel, fontsize=axlabelfontsize[1], rotation=axlabelar[1], fontname=axtickfontname)
  948. if axxlabel:
  949. ax.set_xlabel(axxlabel, fontsize=axlabelfontsize[0], rotation=axlabelar[0], fontname=axtickfontname)
  950. # ylm must be tuple of start, end, interval
  951. if ylm:
  952. plt.ylim(bottom=ylm[0], top=ylm[1])
  953. plt.yticks(np.arange(ylm[0], ylm[1], ylm[2]), fontsize=axtickfontsize[1],
  954. fontname=axtickfontname)
  955. if plotlegend:
  956. plt.legend(loc=legendpos, bbox_to_anchor=legendanchor, ncol=legendcols, fontsize=legendfontsize,
  957. frameon=legendlabelframe, columnspacing=legend_columnspacing)
  958. if isinstance(add_text, list):
  959. plt.text(add_text[0], add_text[1], add_text[2], fontsize=9, fontfamily='Arial')
  960. if dotplot:
  961. for cols in range(len(variable_list)):
  962. move_fact = 0
  963. for cols1 in range(len(sample_list)):
  964. ax.scatter(x=np.linspace(xbar[cols] - bw_fact + move_fact, xbar[cols] + bw_fact + move_fact,
  965. int(df.groupby(samp_col_name).count().loc[sample_list[cols1], variable_list[cols]])),
  966. y=df_melt[(df_melt['variable'] == df_melt['variable'].unique()[cols]) & (
  967. df_melt[samp_col_name] == sample_list[cols1])]['value'], s=dotplot_opts['dotsize'],
  968. color=dotplot_opts['color'], zorder=10, alpha=dotplot_opts['valpha'],
  969. marker=dotplot_opts['marker'])
  970. move_fact += 2 * bw_fact
  971. size_factor_to_start_line = max(df_mean.max()) / div_fact
  972. y_pos_dict = dict()
  973. y_pos_dict_trt = dict()
  974. if add_sign_line:
  975. if len(colbar) == 2:
  976. for i in xbar:
  977. x_pos = xbar[i]
  978. x_pos_2 = xbar[i] + bw
  979. y_pos = df_mean[colbar[0]].to_numpy()[i] + df_sem[colerrorbar[0]].to_numpy()[i]
  980. y_pos_2 = df_mean[colbar[1]].to_numpy()[i] + df_sem[colerrorbar[1]].to_numpy()[i]
  981. # only if y axis is positive
  982. if y_pos > 0:
  983. y_pos += 0.5
  984. y_pos_2 += 0.5
  985. pv_symb = general.pvalue_symbol(pv[i], sign_line_opts['symbol'])
  986. if pv_symb:
  987. ax.annotate('', xy=(x_pos, max(y_pos, y_pos_2)), xytext=(x_pos_2, max(y_pos, y_pos_2)),
  988. arrowprops={'connectionstyle': 'bar, armA=50, armB=50, angle=180, fraction=0 ',
  989. 'arrowstyle': sign_line_opts['arrowstyle'],
  990. 'linewidth': sign_line_opts['linewidth']})
  991. ax.annotate(pv_symb, xy=(np.mean([x_pos, x_pos_2]), max(y_pos, y_pos_2) +
  992. sign_line_opts['dist_y_pos']),
  993. fontsize=sign_line_opts['fontsize'], ha="center")
  994. else:
  995. y_pos -= 0.5
  996. y_pos_2 -= 0.5
  997. pv_symb = general.pvalue_symbol(pv[i], sign_line_opts['symbol'])
  998. if pv_symb:
  999. ax.annotate('', xy=(x_pos, y_pos), xytext=(x_pos_2, y_pos),
  1000. arrowprops={'connectionstyle': 'bar, armA=50, armB=50, angle=180, fraction=-1 ',
  1001. 'arrowstyle': sign_line_opts['arrowstyle'],
  1002. 'linewidth': sign_line_opts['linewidth']})
  1003. ax.annotate(pv_symb, xy=(np.mean([x_pos, x_pos_2]), min(y_pos_2, y_pos) -
  1004. sign_line_opts['dist_y_neg']),
  1005. fontsize=sign_line_opts['fontsize'], ha="center")
  1006. elif len(colbar) == 3:
  1007. for i in xbar:
  1008. x_pos = xbar[i]
  1009. x_pos_2 = xbar[i] + bw
  1010. x_pos_3 = xbar[i] + (2 * bw)
  1011. y_pos = df_mean[colbar[0]].to_numpy()[i] + df_sem[colerrorbar[0]].to_numpy()[i]
  1012. y_pos_2 = df_mean[colbar[1]].to_numpy()[i] + df_sem[colerrorbar[1]].to_numpy()[i]
  1013. y_pos_3 = df_mean[colbar[2]].to_numpy()[i] + df_sem[colerrorbar[2]].to_numpy()[i]
  1014. # only if y axis is positive
  1015. if y_pos > 0:
  1016. y_pos += size_factor_to_start_line / 2
  1017. y_pos_2 += size_factor_to_start_line / 2
  1018. y_pos_3 += size_factor_to_start_line / 2
  1019. pv_symb1 = general.pvalue_symbol(pv[i][0], sign_line_opts['symbol'])
  1020. pv_symb2 = general.pvalue_symbol(pv[i][1], sign_line_opts['symbol'])
  1021. if pv_symb1:
  1022. if max(y_pos, y_pos_2) >= y_pos_3:
  1023. pass
  1024. ax.annotate('', xy=(x_pos, max(y_pos, y_pos_2)), xytext=(x_pos_2, max(y_pos, y_pos_2)),
  1025. arrowprops={'connectionstyle': 'bar, armA=50, armB=50, angle=180, fraction=0 ',
  1026. 'arrowstyle': sign_line_opts['arrowstyle'],
  1027. 'linewidth': sign_line_opts['linewidth']})
  1028. ax.annotate(pv_symb1, xy=(np.mean([x_pos, x_pos_2]), max(y_pos, y_pos_2) +
  1029. size_factor_to_start_line),
  1030. fontsize=sign_line_opts['fontsize'], ha="center")
  1031. if pv_symb2:
  1032. if max(y_pos, y_pos_3) < y_pos_2:
  1033. y_pos_3 = y_pos_2 + (4 * size_factor_to_start_line)
  1034. ax.annotate('', xy=(x_pos, max(y_pos, y_pos_3)), xytext=(x_pos_3, max(y_pos, y_pos_3)),
  1035. arrowprops={'connectionstyle': 'bar, armA=50, armB=50, angle=180, fraction=0 ',
  1036. 'arrowstyle': sign_line_opts['arrowstyle'],
  1037. 'linewidth': sign_line_opts['linewidth']})
  1038. ax.annotate(pv_symb2, xy=(np.mean([x_pos, x_pos_3]), max(y_pos, y_pos_3) +
  1039. size_factor_to_start_line),
  1040. fontsize=sign_line_opts['fontsize'], ha="center")
  1041. else:
  1042. y_pos -= 0.5
  1043. y_pos_2 -= 0.5
  1044. pv_symb = general.pvalue_symbol(pv[i], sign_line_opts['symbol'])
  1045. if pv_symb:
  1046. ax.annotate('', xy=(x_pos, y_pos), xytext=(x_pos_2, y_pos),
  1047. arrowprops={'connectionstyle': 'bar, armA=50, armB=50, angle=180, fraction=-1 ',
  1048. 'arrowstyle': sign_line_opts['arrowstyle'],
  1049. 'linewidth': sign_line_opts['linewidth']})
  1050. ax.annotate(pv_symb, xy=(np.mean([x_pos, x_pos_2]), min(y_pos_2, y_pos) -
  1051. sign_line_opts['dist_y_neg']),
  1052. fontsize=sign_line_opts['fontsize'], ha="center")
  1053. if add_sign_symbol:
  1054. if len(colbar) == 2:
  1055. for i in xbar:
  1056. x_pos = xbar[i]
  1057. x_pos_2 = xbar[i] + bw
  1058. if symb_dist:
  1059. # max value size factor is essential for rel pos of symbol
  1060. y_pos = df_mean[colbar[0]].to_numpy()[i] + df_sem[colerrorbar[0]].to_numpy()[i] + \
  1061. (max(df_mean[colbar[0]].to_numpy()) / 20) + symb_dist[i][0]
  1062. y_pos_2 = df_mean[colbar[1]].to_numpy()[i] + df_sem[colerrorbar[1]].to_numpy()[i] + \
  1063. (max(df_mean[colbar[1]].to_numpy()) / 20) + symb_dist[i][1]
  1064. else:
  1065. y_pos = df_mean[colbar[0]].to_numpy()[i] + df_sem[colerrorbar[0]].to_numpy()[i] + \
  1066. (max(df_mean[colbar[0]].to_numpy()) / 20)
  1067. y_pos_2 = df_mean[colbar[1]].to_numpy()[i] + df_sem[colerrorbar[1]].to_numpy()[i] + \
  1068. (max(df_mean[colbar[1]].to_numpy()) / 20)
  1069. '''
  1070. y_pos = df[colbar[0]].to_numpy()[i] + df[colerrorbar[0]].to_numpy()[i] + \
  1071. (max(df[colbar[0]].to_numpy()) / 20)
  1072. y_pos_2 = df[colbar[1]].to_numpy()[i] + df[colerrorbar[1]].to_numpy()[i] + \
  1073. (max(df[colbar[1]].to_numpy()) / 20)
  1074. '''
  1075. # group_let_df need index column
  1076. if isinstance(group_let_df, pd.DataFrame):
  1077. # only if y axis is positive
  1078. if y_pos > 0:
  1079. if not pd.isnull(group_let_df.loc[colbar[0], xbarcol[i]]):
  1080. plt.annotate(group_let_df.loc[colbar[0], xbarcol[i]], xy=(x_pos, y_pos),
  1081. fontsize=sign_symbol_opts['fontsize'], ha='center',
  1082. fontfamily=sign_symbol_opts['fontname'],
  1083. rotation=sign_symbol_opts['rotation'])
  1084. if y_pos_2 > 0:
  1085. if not pd.isnull(group_let_df.loc[colbar[1], xbarcol[i]]):
  1086. plt.annotate(group_let_df.loc[colbar[1], xbarcol[i]], xy=(x_pos_2, y_pos_2),
  1087. fontsize=sign_symbol_opts['fontsize'], ha='center',
  1088. fontfamily=sign_symbol_opts['fontname'],
  1089. rotation=sign_symbol_opts['rotation'])
  1090. # only if y axis is positive
  1091. # need to verify this
  1092. elif pv:
  1093. if y_pos > 0:
  1094. pv_symb_1 = general.pvalue_symbol(pv[i][0], sign_symbol_opts['symbol'])
  1095. pv_symb_2 = general.pvalue_symbol(pv[i][1], sign_symbol_opts['symbol'])
  1096. if pv_symb_1:
  1097. plt.annotate(pv_symb_1, xy=(x_pos, y_pos), fontsize=sign_symbol_opts['fontsize'],
  1098. ha="center", fontfamily=sign_symbol_opts['fontname'],
  1099. rotation=sign_symbol_opts['rotation'])
  1100. if pv_symb_2:
  1101. plt.annotate(pv_symb_2, xy=(x_pos_2, y_pos_2), fontsize=sign_symbol_opts['fontsize'],
  1102. ha="center", fontfamily=sign_symbol_opts['fontname'],
  1103. rotation=sign_symbol_opts['rotation'])
  1104. else:
  1105. raise Exception('Either group dataframe of p value list is required')
  1106. elif len(colbar) == 3:
  1107. for i in xbar:
  1108. x_pos = xbar[i]
  1109. x_pos_2 = xbar[i] + bw
  1110. x_pos_3 = xbar[i] + (2 * bw)
  1111. if symb_dist:
  1112. # max value size factor is essential for rel pos of symbol
  1113. y_pos = df_mean[colbar[0]].to_numpy()[i] + df_sem[colerrorbar[0]].to_numpy()[i] + \
  1114. (max(df_mean[colbar[0]].to_numpy()) / 20) + symb_dist[i][0]
  1115. y_pos_2 = df_mean[colbar[1]].to_numpy()[i] + df_sem[colerrorbar[1]].to_numpy()[i] + \
  1116. (max(df_mean[colbar[1]].to_numpy()) / 20) + symb_dist[i][1]
  1117. y_pos_3 = df_mean[colbar[2]].to_numpy()[i] + df_sem[colerrorbar[2]].to_numpy()[i] + \
  1118. (max(df_mean[colbar[2]].to_numpy()) / 20) + symb_dist[i][2]
  1119. else:
  1120. y_pos = df_mean[colbar[0]].to_numpy()[i] + df_sem[colerrorbar[0]].to_numpy()[i] + \
  1121. (max(df_mean[colbar[0]].to_numpy()) / 20)
  1122. y_pos_2 = df_mean[colbar[1]].to_numpy()[i] + df_sem[colerrorbar[1]].to_numpy()[i] + \
  1123. (max(df_mean[colbar[1]].to_numpy()) / 20)
  1124. y_pos_3 = df_mean[colbar[2]].to_numpy()[i] + df_sem[colerrorbar[2]].to_numpy()[i] + \
  1125. (max(df_mean[colbar[2]].to_numpy()) / 20)
  1126. # group_let_df need index column
  1127. if isinstance(group_let_df, pd.DataFrame):
  1128. if y_pos > 0:
  1129. plt.annotate(group_let_df.loc[colbar[0], xbarcol[i]], xy=(x_pos, y_pos),
  1130. fontsize=sign_symbol_opts['fontsize'], ha="center",
  1131. fontfamily=sign_symbol_opts['fontname'], rotation=sign_symbol_opts['rotation'])
  1132. if y_pos_2 > 0:
  1133. plt.annotate(group_let_df.loc[colbar[1], xbarcol[i]], xy=(x_pos_2, y_pos_2),
  1134. fontsize=sign_symbol_opts['fontsize'], ha="center",
  1135. fontfamily=sign_symbol_opts['fontname'], rotation=sign_symbol_opts['rotation'])
  1136. if y_pos_3 > 0:
  1137. plt.annotate(group_let_df.loc[colbar[2], xbarcol[i]], xy=(x_pos_3, y_pos_3),
  1138. fontsize=sign_symbol_opts['fontsize'], ha="center",
  1139. fontfamily=sign_symbol_opts['fontname'], rotation=sign_symbol_opts['rotation'])
  1140. if pv:
  1141. # only if y axis is positive
  1142. if y_pos > 0:
  1143. pv_symb_1 = general.pvalue_symbol(pv[i][0], sign_symbol_opts['symbol'])
  1144. pv_symb_2 = general.pvalue_symbol(pv[i][1], sign_symbol_opts['symbol'])
  1145. pv_symb_3 = general.pvalue_symbol(pv[i][2], sign_symbol_opts['symbol'])
  1146. if pv_symb_1:
  1147. plt.annotate(pv_symb_1, xy=(x_pos, y_pos), fontsize=sign_symbol_opts['fontsize'],
  1148. ha="center", fontfamily=sign_symbol_opts['fontname'],
  1149. rotation=sign_symbol_opts['rotation'])
  1150. if pv_symb_2:
  1151. plt.annotate(pv_symb_2, xy=(x_pos_2, y_pos_2), fontsize=sign_symbol_opts['fontsize'],
  1152. ha="center", fontfamily=sign_symbol_opts['fontname'],
  1153. rotation=sign_symbol_opts['rotation'])
  1154. if pv_symb_3:
  1155. plt.annotate(pv_symb_3, xy=(x_pos_3, y_pos_3), fontsize=sign_symbol_opts['fontsize'],
  1156. ha="center", fontfamily=sign_symbol_opts['fontname'],
  1157. rotation=sign_symbol_opts['rotation'])
  1158. elif len(colbar) == 4:
  1159. for i in xbar:
  1160. x_pos = xbar[i]
  1161. x_pos_2 = xbar[i] + bw
  1162. x_pos_3 = xbar[i] + (2 * bw)
  1163. x_pos_4 = xbar[i] + (3 * bw)
  1164. if symb_dist:
  1165. # max value size factor is essential for rel pos of symbol
  1166. y_pos = df_mean[colbar[0]].to_numpy()[i] + df_sem[colerrorbar[0]].to_numpy()[i] + \
  1167. (max(df_mean[colbar[0]].to_numpy()) / 20) + symb_dist[i][0]
  1168. y_pos_2 = df_mean[colbar[1]].to_numpy()[i] + df_sem[colerrorbar[1]].to_numpy()[i] + \
  1169. (max(df_mean[colbar[1]].to_numpy()) / 20) + symb_dist[i][1]
  1170. y_pos_3 = df_mean[colbar[2]].to_numpy()[i] + df_sem[colerrorbar[2]].to_numpy()[i] + \
  1171. (max(df_mean[colbar[2]].to_numpy()) / 20) + symb_dist[i][2]
  1172. y_pos_4 = df_mean[colbar[3]].to_numpy()[i] + df_sem[colerrorbar[3]].to_numpy()[i] + \
  1173. (max(df_mean[colbar[3]].to_numpy()) / 20) + symb_dist[i][3]
  1174. else:
  1175. y_pos = df_mean[colbar[0]].to_numpy()[i] + df_sem[colerrorbar[0]].to_numpy()[i] + \
  1176. (max(df_mean[colbar[0]].to_numpy()) / 20)
  1177. y_pos_2 = df_mean[colbar[1]].to_numpy()[i] + df_sem[colerrorbar[1]].to_numpy()[i] + \
  1178. (max(df_mean[colbar[1]].to_numpy()) / 20)
  1179. y_pos_3 = df_mean[colbar[2]].to_numpy()[i] + df_sem[colerrorbar[2]].to_numpy()[i] + \
  1180. (max(df_mean[colbar[2]].to_numpy()) / 20)
  1181. y_pos_4 = df_mean[colbar[3]].to_numpy()[i] + df_sem[colerrorbar[3]].to_numpy()[i] + \
  1182. (max(df_mean[colbar[3]].to_numpy()) / 20)
  1183. # group_let_df need index column
  1184. if isinstance(group_let_df, pd.DataFrame):
  1185. # only if y axis is positive
  1186. if y_pos > 0:
  1187. plt.annotate(group_let_df.loc[colbar[0], xbarcol[i]], xy=(x_pos, y_pos),
  1188. fontsize=sign_symbol_opts['fontsize'], ha="center",
  1189. fontfamily=sign_symbol_opts['fontname'], rotation=sign_symbol_opts['rotation'])
  1190. if y_pos_2 > 0:
  1191. plt.annotate(group_let_df.loc[colbar[1], xbarcol[i]], xy=(x_pos_2, y_pos_2),
  1192. fontsize=sign_symbol_opts['fontsize'], ha="center",
  1193. fontfamily=sign_symbol_opts['fontname'], rotation=sign_symbol_opts['rotation'])
  1194. if y_pos_3 > 0:
  1195. plt.annotate(group_let_df.loc[colbar[2], xbarcol[i]], xy=(x_pos_3, y_pos_3),
  1196. fontsize=sign_symbol_opts['fontsize'], ha="center",
  1197. fontfamily=sign_symbol_opts['fontname'], rotation=sign_symbol_opts['rotation'])
  1198. if y_pos_4 > 0:
  1199. plt.annotate(group_let_df.loc[colbar[3], xbarcol[i]], xy=(x_pos_4, y_pos_4),
  1200. fontsize=sign_symbol_opts['fontsize'], ha="center",
  1201. fontfamily=sign_symbol_opts['fontname'], rotation=sign_symbol_opts['rotation'])
  1202. # need to work on this for 4 bars
  1203. if pv:
  1204. pv_symb_1 = general.pvalue_symbol(pv[i][0], sign_symbol_opts['symbol'])
  1205. pv_symb_2 = general.pvalue_symbol(pv[i][1], sign_symbol_opts['symbol'])
  1206. pv_symb_3 = general.pvalue_symbol(pv[i][2], sign_symbol_opts['symbol'])
  1207. pv_symb_4 = general.pvalue_symbol(pv[i][3], sign_symbol_opts['symbol'])
  1208. if pv_symb_1:
  1209. plt.annotate(pv_symb_1, xy=(x_pos, y_pos), fontsize=sign_symbol_opts['fontsize'],
  1210. ha="center", fontfamily=sign_symbol_opts['fontname'],
  1211. rotation=sign_symbol_opts['rotation'])
  1212. if pv_symb_2:
  1213. plt.annotate(pv_symb_2, xy=(x_pos_2, y_pos_2), fontsize=sign_symbol_opts['fontsize'],
  1214. ha="center", fontfamily=sign_symbol_opts['fontname'],
  1215. rotation=sign_symbol_opts['rotation'])
  1216. if pv_symb_3:
  1217. plt.annotate(pv_symb_3, xy=(x_pos_3, y_pos_3), fontsize=sign_symbol_opts['fontsize'],
  1218. ha="center", fontfamily=sign_symbol_opts['fontname'],
  1219. rotation=sign_symbol_opts['rotation'])
  1220. if pv_symb_4:
  1221. plt.annotate(pv_symb_4, xy=(x_pos_4, y_pos_4), fontsize=sign_symbol_opts['fontsize'],
  1222. ha="center", fontfamily=sign_symbol_opts['fontname'],
  1223. rotation=sign_symbol_opts['rotation'])
  1224. elif len(colbar) == 5:
  1225. for i in xbar:
  1226. x_pos = xbar[i]
  1227. x_pos_2 = xbar[i] + bw
  1228. x_pos_3 = xbar[i] + (2 * bw)
  1229. x_pos_4 = xbar[i] + (3 * bw)
  1230. x_pos_5 = xbar[i] + (4 * bw)
  1231. # max value size factor is essential for rel pos of symbol
  1232. if symb_dist:
  1233. y_pos = df_mean[colbar[0]].to_numpy()[i] + df_sem[colerrorbar[0]].to_numpy()[i] + \
  1234. (max(df_mean[colbar[0]].to_numpy()) / 20) + symb_dist[i][0]
  1235. y_pos_2 = df_mean[colbar[1]].to_numpy()[i] + df_sem[colerrorbar[1]].to_numpy()[i] + \
  1236. (max(df_mean[colbar[1]].to_numpy()) / 20) + symb_dist[i][1]
  1237. y_pos_3 = df_mean[colbar[2]].to_numpy()[i] + df_sem[colerrorbar[2]].to_numpy()[i] + \
  1238. (max(df_mean[colbar[2]].to_numpy()) / 20) + symb_dist[i][2]
  1239. y_pos_4 = df_mean[colbar[3]].to_numpy()[i] + df_sem[colerrorbar[3]].to_numpy()[i] + \
  1240. (max(df_mean[colbar[3]].to_numpy()) / 20) + symb_dist[i][3]
  1241. y_pos_5 = df_mean[colbar[4]].to_numpy()[i] + df_sem[colerrorbar[4]].to_numpy()[i] + \
  1242. (max(df_mean[colbar[4]].to_numpy()) / 20) + symb_dist[i][4]
  1243. else:
  1244. y_pos = df_mean[colbar[0]].to_numpy()[i] + df_sem[colerrorbar[0]].to_numpy()[i] + \
  1245. (max(df_mean[colbar[0]].to_numpy()) / 20)
  1246. y_pos_2 = df_mean[colbar[1]].to_numpy()[i] + df_sem[colerrorbar[1]].to_numpy()[i] + \
  1247. (max(df_mean[colbar[1]].to_numpy()) / 20)
  1248. y_pos_3 = df_mean[colbar[2]].to_numpy()[i] + df_sem[colerrorbar[2]].to_numpy()[i] + \
  1249. (max(df_mean[colbar[2]].to_numpy()) / 20)
  1250. y_pos_4 = df_mean[colbar[3]].to_numpy()[i] + df_sem[colerrorbar[3]].to_numpy()[i] + \
  1251. (max(df_mean[colbar[3]].to_numpy()) / 20)
  1252. y_pos_5 = df_mean[colbar[4]].to_numpy()[i] + df_sem[colerrorbar[4]].to_numpy()[i] + \
  1253. (max(df_mean[colbar[4]].to_numpy()) / 20)
  1254. # group_let_df need index column
  1255. if isinstance(group_let_df, pd.DataFrame):
  1256. # only if y axis is positive
  1257. if y_pos > 0:
  1258. plt.annotate(group_let_df.loc[colbar[0], xbarcol[i]], xy=(x_pos, y_pos),
  1259. fontsize=sign_symbol_opts['fontsize'], ha="center")
  1260. if y_pos_2 > 0:
  1261. plt.annotate(group_let_df.loc[colbar[1], xbarcol[i]], xy=(x_pos_2, y_pos_2),
  1262. fontsize=sign_symbol_opts['fontsize'], ha="center")
  1263. if y_pos_3 > 0:
  1264. plt.annotate(group_let_df.loc[colbar[2], xbarcol[i]], xy=(x_pos_3, y_pos_3),
  1265. fontsize=sign_symbol_opts['fontsize'], ha="center")
  1266. if y_pos_4 > 0:
  1267. plt.annotate(group_let_df.loc[colbar[3], xbarcol[i]], xy=(x_pos_4, y_pos_4),
  1268. fontsize=sign_symbol_opts['fontsize'], ha="center")
  1269. if y_pos_5 > 0:
  1270. plt.annotate(group_let_df.loc[colbar[4], xbarcol[i]], xy=(x_pos_5, y_pos_5),
  1271. fontsize=sign_symbol_opts['fontsize'], ha="center")
  1272. # need to work on this for 4 bars
  1273. if pv:
  1274. pv_symb_1 = general.pvalue_symbol(pv[i][0], sign_symbol_opts['symbol'])
  1275. pv_symb_2 = general.pvalue_symbol(pv[i][1], sign_symbol_opts['symbol'])
  1276. pv_symb_3 = general.pvalue_symbol(pv[i][2], sign_symbol_opts['symbol'])
  1277. if pv_symb_1:
  1278. plt.annotate(pv_symb_1, xy=(x_pos, y_pos), fontsize=sign_symbol_opts['fontsize'],
  1279. ha="center")
  1280. if pv_symb_2:
  1281. plt.annotate(pv_symb_2, xy=(x_pos_2, y_pos_2), fontsize=sign_symbol_opts['fontsize'],
  1282. ha="center")
  1283. if pv_symb_3:
  1284. plt.annotate(pv_symb_3, xy=(x_pos_3, y_pos_3), fontsize=sign_symbol_opts['fontsize'],
  1285. ha="center")
  1286. sub_cat_i = 0
  1287. if sub_cat:
  1288. if isinstance(sub_cat, dict):
  1289. for k in sub_cat:
  1290. if isinstance(k, tuple) and len(k) == 2:
  1291. cat_x_pos, cat_y_pos, cat_x_pos_2 = k[0], min_value - \
  1292. (sub_cat_opts[
  1293. 'y_neg_dist'] * size_factor_to_start_line), k[1]
  1294. plt.annotate('', xy=(cat_x_pos - (bw / 2), cat_y_pos),
  1295. xytext=(cat_x_pos_2 + (bw / 2), cat_y_pos),
  1296. arrowprops={'arrowstyle': '-', 'linewidth': 0.5}, annotation_clip=False)
  1297. if sub_cat_label_dist and isinstance(sub_cat_label_dist, list):
  1298. plt.annotate(sub_cat[k], xy=(np.mean([cat_x_pos, cat_x_pos_2]),
  1299. cat_y_pos - size_factor_to_start_line - sub_cat_label_dist[
  1300. sub_cat_i]),
  1301. ha="center", fontsize=sub_cat_opts['fontsize'], annotation_clip=False,
  1302. fontfamily=sub_cat_opts['fontname'])
  1303. sub_cat_i += 1
  1304. else:
  1305. plt.annotate(sub_cat[k], xy=(np.mean([cat_x_pos, cat_x_pos_2]),
  1306. cat_y_pos - size_factor_to_start_line),
  1307. ha="center", fontsize=sub_cat_opts['fontsize'], annotation_clip=False,
  1308. fontfamily=sub_cat_opts['fontname'])
  1309. else:
  1310. raise KeyError("Sub category keys must be tuple of size 2")
  1311. general.get_figure(show, r, figtype, figname, theme)
  # for data with replicates
  # deprecate dist_y_pos and dist_y_neg (replace with size_factor_to_start_line)
  @staticmethod
  def singlebar(df='dataframe', dim=(6, 4), bw=0.4, colorbar='#f2aa4cff', hbsize=4, r=300, ar=(0, 0), valphabar=1,
                errorbar=True, show=False, ylm=None, axtickfontsize=9, axtickfontname='Arial', ax_x_ticklabel=None,
                axlabelfontsize=9, axlabelfontname='Arial', yerrlw=None, yerrcw=None, axxlabel=None, axylabel=None,
                figtype='png', add_sign_line=False, pv=None,
                sign_line_opts={'symbol': '*', 'fontsize': 9, 'linewidth': 0.5, 'arrowstyle': '-', 'fontname':'Arial'},
                sign_line_pvals=False,
                add_sign_symbol=False, sign_symbol_opts={'symbol': '*', 'fontsize': 9, 'rotation':0, 'fontname':'Arial'},
                sign_line_pairs=None, sub_cat=None, sub_cat_opts={'y_neg_dist': 3.5, 'fontsize': 9, 'fontname':'Arial'},
                sub_cat_label_dist=None, symb_dist=None, group_let=None, df_format=None, samp_col_name=None,
                col_order=False, dotplot=False, dotsize=6, colordot=['#101820ff'], valphadot=1, markerdot='o',
                sign_line_pairs_dist=None, sign_line_pv_symb_dist=None, div_fact=20, add_text=None,
                figname='singlebar', connectionstyle='bar, armA=50, armB=50, angle=180, fraction=0',
                std_errs_vis='both', yerrzorder=8, theme=None):
      """Draw one bar per sample (mean with SEM error bars) plus optional significance annotations.

      The option dicts are only read, never mutated, so sharing their default objects is safe.

      Parameters (selected)
      ---------------------
      df : pandas.DataFrame
          Unstacked (default): one column per sample, one row per replicate.
          Stacked (``df_format='stack'``): a sample column ``samp_col_name`` plus value column(s).
          Only the first value column is plotted.
      col_order : bool
          Stacked data only. Keep samples in order of first appearance instead of groupby order.
      std_errs_vis : {'both', 'upper', 'lower'}
          Which side(s) of the SEM error bar to draw.
      dotplot : bool
          Overlay the individual replicate values on each bar.
      add_sign_line : bool
          Connect adjacent bar pairs (0, 1), (2, 3), ... and label them with ``pv[pair_index]``.
          A trailing unpaired bar is skipped.
      sign_line_opts : dict
          Keys 'symbol', 'fontsize', 'linewidth', 'arrowstyle'.
          Optional 'dist_y_pos' is the label offset above the line. When it is absent, the
          offset falls back to ``size_factor_to_start_line``.
      sign_line_pairs : list of (int, int)
          Explicit bar index pairs to connect. ``pv`` is indexed in the same order.
      sign_line_pvals : bool
          Print the raw p value instead of a significance symbol.
      sign_line_pairs_dist, sign_line_pv_symb_dist : list or None
          Per-pair extra offsets for the line and its label. None means 0.
      add_sign_symbol : bool
          Put a per-bar label above each bar, taken from ``group_let`` (list) or ``pv`` (list).
      symb_dist : list or None
          Per-bar extra offset for those labels.
      sub_cat : dict or None
          Maps ``(first_bar_index, last_bar_index)`` to a group label drawn below the x axis.
      div_fact : number
          ``size_factor_to_start_line = max(bar heights) / div_fact`` is the basic vertical spacing unit.
      add_text : list or None
          ``[x, y, text]`` drawn with ``plt.text``.

      Raises
      ------
      ValueError
          If stacked data lacks ``samp_col_name``, or ``std_errs_vis`` is not recognised.
      KeyError
          If a ``sub_cat`` key is not a 2-tuple.
      """
      plt.rcParams['mathtext.fontset'] = 'custom'
      plt.rcParams['mathtext.default'] = 'regular'
      plt.rcParams['mathtext.it'] = 'Arial:italic'
      plt.rcParams['mathtext.bf'] = 'Arial:italic:bold'
      # set axis labels to None
      _x = None
      _y = None
      if df_format == 'stack':
          if samp_col_name is None:
              raise ValueError('sample column name required')
          df_mean = df.groupby(samp_col_name).mean().reset_index().set_index(samp_col_name).T
          df_sem = df.groupby(samp_col_name).sem().reset_index().set_index(samp_col_name).T
          if col_order:
              df_mean = df_mean[df[samp_col_name].unique()]
              df_sem = df_sem[df[samp_col_name].unique()]
          bar_h = df_mean.iloc[0]
          bar_se = df_sem.iloc[0]
          sample_list = df_mean.columns.to_numpy()
          # only the first value column is plotted; the dot plot needs its name
          stack_value_col = df_mean.index[0]
          # lowest plotted value, clamped to 0 when nothing is negative
          min_value = min(0, df_mean.iloc[0].min())
      else:
          bar_h = df.describe().loc['mean']
          bar_se = df.sem()
          sample_list = df.columns.to_numpy()
          min_value = min(0, min(df.min()))
      if std_errs_vis == 'upper':
          yerr = [len(bar_se) * [0], bar_se]
      elif std_errs_vis == 'lower':
          yerr = [bar_se, len(bar_se) * [0]]
      elif std_errs_vis == 'both':
          yerr = bar_se
      else:
          raise ValueError('In valid value for the std_errs_vis')
      xbar = np.arange(len(sample_list))
      if theme == 'dark':
          general.dark_bg()
      plt.subplots(figsize=dim)
      if errorbar:
          plt.bar(x=xbar, height=bar_h, yerr=yerr, width=bw, color=colorbar, capsize=hbsize, alpha=valphabar,
                  zorder=5, error_kw={'elinewidth': yerrlw, 'capthick': yerrcw, 'zorder': yerrzorder})
      else:
          plt.bar(x=xbar, height=bar_h, width=bw, color=colorbar, capsize=hbsize, alpha=valphabar)
      x_ticklabel = ax_x_ticklabel if ax_x_ticklabel else sample_list
      plt.xticks(ticks=xbar, labels=x_ticklabel, fontsize=axtickfontsize, rotation=ar[0], fontname=axtickfontname)
      if axxlabel:
          _x = axxlabel
      if axylabel:
          _y = axylabel
      general.axis_labels(_x, _y, axlabelfontsize, axlabelfontname)
      # ylm must be tuple of start, end, interval
      if ylm:
          plt.ylim(bottom=ylm[0], top=ylm[1])
          plt.yticks(np.arange(ylm[0], ylm[1], ylm[2]), fontsize=axtickfontsize, fontname=axtickfontname)
      plt.yticks(fontsize=axtickfontsize, rotation=ar[1], fontname=axtickfontname)
      color_list_dot = colordot * len(sample_list) if len(colordot) == 1 else colordot
      if dotplot:
          for pos, sample in enumerate(sample_list):
              if df_format == 'stack':
                  values = df.loc[df[samp_col_name] == sample, stack_value_col].dropna()
              else:
                  values = df.iloc[:, pos].dropna()
              # spread the replicates evenly across the width of their bar
              plt.scatter(x=np.linspace(xbar[pos] - bw / 2, xbar[pos] + bw / 2, len(values)), y=values,
                          s=dotsize, color=color_list_dot[pos], zorder=10, alpha=valphadot, marker=markerdot)
      # bar means and bar tops (mean + SEM) of what was actually plotted (works for stacked data too)
      bar_means = bar_h.to_numpy()
      bar_tops = bar_means + bar_se.to_numpy()
      size_factor_to_start_line = max(bar_h) / div_fact
      # for only adjacent bars (not for multiple bars with single control)
      if add_sign_line:
          # 'dist_y_pos' is deprecated; fall back to the size factor when it is not supplied
          dist_y_pos = sign_line_opts.get('dist_y_pos', size_factor_to_start_line)
          for i in xbar:
              # pairs are (0, 1), (2, 3), ...; a trailing unpaired bar has nothing to connect to
              if i % 2 != 0 or i + 1 >= len(xbar):
                  continue
              x_pos, x_pos_2 = xbar[i], xbar[i + 1]
              y_pos, y_pos_2 = bar_tops[i], bar_tops[i + 1]
              # only if y axis is positive
              if y_pos > 0:
                  line_y = max(y_pos, y_pos_2) + size_factor_to_start_line
                  pv_symb = general.pvalue_symbol(pv[i // 2], sign_line_opts['symbol'])
                  if pv_symb:
                      plt.annotate('', xy=(x_pos, line_y), xytext=(x_pos_2, line_y),
                                   arrowprops={'connectionstyle': connectionstyle,
                                               'arrowstyle': sign_line_opts['arrowstyle'],
                                               'linewidth': sign_line_opts['linewidth']})
                      plt.annotate(pv_symb, xy=(np.mean([x_pos, x_pos_2]), line_y + dist_y_pos),
                                   fontsize=sign_line_opts['fontsize'], ha="center")
      # for only adjacent bars with one control but multiple treatments
      p_index = 0
      y_pos_dict = dict()
      y_pos_dict_trt = dict()
      if sign_line_pairs:
          for i in sign_line_pairs:
              y_pos_adj = 0
              x_pos = xbar[i[0]]
              x_pos_2 = xbar[i[1]]
              y_pos = bar_tops[i[0]]
              y_pos_2 = bar_tops[i[1]]
              # only if y axis is positive
              if y_pos > 0:
                  y_pos += size_factor_to_start_line / 2
                  y_pos_2 += size_factor_to_start_line / 2
                  # raise this line above an earlier line from the same control that it would otherwise cross,
                  # e.g. when 0-1 already has a higher sign bar than 0-2
                  if i[0] in y_pos_dict_trt:
                      y_pos_adj = 1
                      if y_pos_2 <= y_pos_dict_trt[i[0]][1]:
                          if sign_line_pairs_dist:
                              y_pos_2 += (y_pos_dict_trt[i[0]][1] - y_pos_2) + (3 * size_factor_to_start_line) + \
                                         sign_line_pairs_dist[p_index]
                          else:
                              y_pos_2 += (y_pos_dict_trt[i[0]][1] - y_pos_2) + (3 * size_factor_to_start_line)
                      elif y_pos <= y_pos_dict_trt[i[0]][0]:
                          if sign_line_pairs_dist:
                              y_pos += 3 * size_factor_to_start_line + sign_line_pairs_dist[p_index]
                          else:
                              y_pos += 3 * size_factor_to_start_line
                  # bars with similar means would produce overlapping lines; add some distance
                  if i[0] in y_pos_dict:
                      y_pos_adj = 1
                      if 0.75 < bar_means[i[0]] / bar_means[i[1]] < 1.25:
                          if sign_line_pairs_dist:
                              y_pos += 2 * size_factor_to_start_line + sign_line_pairs_dist[p_index]
                          else:
                              y_pos += 2 * size_factor_to_start_line
                  if y_pos_adj == 0 and sign_line_pairs_dist:
                      if y_pos >= y_pos_2:
                          y_pos += sign_line_pairs_dist[p_index]
                      else:
                          y_pos_2 += sign_line_pairs_dist[p_index]
                  # sign_line_pvals passed: print the p value instead of a symbol
                  if sign_line_pvals:
                      pv_symb = r'$\it{p}$' + str(pv[p_index])
                  else:
                      pv_symb = general.pvalue_symbol(pv[p_index], sign_line_opts['symbol'])
                  y_pos_dict[i[0]] = y_pos
                  y_pos_dict_trt[i[0]] = [y_pos, y_pos_2]
                  if pv_symb:
                      line_y = max(y_pos, y_pos_2)
                      # the label offset is optional; None means no extra offset
                      label_extra = sign_line_pv_symb_dist[p_index] if sign_line_pv_symb_dist else 0
                      plt.annotate('', xy=(x_pos, line_y), xytext=(x_pos_2, line_y),
                                   arrowprops={'connectionstyle': connectionstyle,
                                               'arrowstyle': sign_line_opts['arrowstyle'],
                                               'linewidth': sign_line_opts['linewidth']})
                      plt.annotate(pv_symb, xy=(np.mean([x_pos, x_pos_2]),
                                                line_y + size_factor_to_start_line + label_extra),
                                   fontsize=sign_line_opts['fontsize'], ha="center")
              p_index += 1
      if add_sign_symbol:
          for i in xbar:
              x_pos = xbar[i]
              y_pos = bar_tops[i] + size_factor_to_start_line
              if symb_dist:
                  y_pos += symb_dist[i]
              # group_let list
              if isinstance(group_let, list):
                  if y_pos > 0:
                      plt.annotate(group_let[i], xy=(x_pos, y_pos),
                                   fontsize=sign_symbol_opts['fontsize'], ha="center",
                                   rotation=sign_symbol_opts['rotation'], fontfamily=sign_symbol_opts['fontname'])
              # only if y axis is positive
              if pv:
                  if y_pos > 0:
                      pv_symb = general.pvalue_symbol(pv[i], sign_symbol_opts['symbol'])
                      if pv_symb:
                          plt.annotate(pv_symb, xy=(x_pos, y_pos), fontsize=sign_symbol_opts['fontsize'],
                                       ha="center", rotation=sign_symbol_opts['rotation'],
                                       fontfamily=sign_symbol_opts['fontname'])
      sub_cat_i = 0
      if sub_cat:
          if isinstance(sub_cat, dict):
              for k in sub_cat:
                  if isinstance(k, tuple) and len(k) == 2:
                      cat_x_pos = k[0]
                      cat_x_pos_2 = k[1]
                      cat_y_pos = min_value - (sub_cat_opts['y_neg_dist'] * size_factor_to_start_line)
                      plt.annotate('', xy=(cat_x_pos - (bw / 2), cat_y_pos),
                                   xytext=(cat_x_pos_2 + (bw / 2), cat_y_pos),
                                   arrowprops={'arrowstyle': '-', 'linewidth': 0.5}, annotation_clip=False)
                      if sub_cat_label_dist and isinstance(sub_cat_label_dist, list):
                          plt.annotate(sub_cat[k], xy=(np.mean([cat_x_pos, cat_x_pos_2]),
                                                       cat_y_pos - size_factor_to_start_line -
                                                       sub_cat_label_dist[sub_cat_i]),
                                       ha="center", fontsize=sub_cat_opts['fontsize'], annotation_clip=False,
                                       fontfamily=sub_cat_opts['fontname'])
                          sub_cat_i += 1
                      else:
                          plt.annotate(sub_cat[k], xy=(np.mean([cat_x_pos, cat_x_pos_2]),
                                                       cat_y_pos - size_factor_to_start_line),
                                       ha="center", fontsize=sub_cat_opts['fontsize'], annotation_clip=False,
                                       fontfamily=sub_cat_opts['fontname'])
                  else:
                      raise KeyError("Sub category keys must be tuple of size 2")
      if isinstance(add_text, list):
          plt.text(add_text[0], add_text[1], add_text[2], fontsize=9, fontfamily='Arial')
      general.get_figure(show, r, figtype, figname, theme)
  @staticmethod
  def normal_bar(df='dataframe', x_col_name=None, y_col_name=None, dim=(6, 4), bw=0.4, colorbar="#f2aa4cff", r=300,
                 ar=(0, 0), valphabar=1, show=False, ylm=None, axtickfontsize=9, axtickfontname='Arial',
                 ax_x_ticklabel=None, axlabelfontsize=9, axlabelfontname='Arial', axxlabel=None, axylabel=None,
                 figtype='png', figname='normal_bar', theme=None):
      """Draw a plain bar chart with one bar per row of ``df`` (no error bars).

      Parameters
      ----------
      df : pandas.DataFrame
          Source data.
      x_col_name : str
          Column supplying the x tick labels, one bar per row.
      y_col_name : str
          Column supplying the bar heights.
      ylm : tuple or None
          ``(start, end, interval)`` for the y axis limits and ticks. Previously accepted but ignored.
      ar : tuple
          ``(x_tick_rotation, y_tick_rotation)``.
      ax_x_ticklabel : list or None
          Overrides the x tick labels taken from ``x_col_name``.

      The y ticks use ``axtickfontsize``/``axtickfontname``, like the x ticks and like the other bar
      plots in this class.
      """
      # set axis labels to None
      _x = None
      _y = None
      xbar = np.arange(len(df[x_col_name]))
      if theme == 'dark':
          general.dark_bg()
      plt.subplots(figsize=dim)
      plt.bar(x=xbar, height=df[y_col_name], width=bw, color=colorbar, alpha=valphabar)
      x_ticklabel = ax_x_ticklabel if ax_x_ticklabel else df[x_col_name].to_numpy()
      plt.xticks(ticks=xbar, labels=x_ticklabel, fontsize=axtickfontsize, rotation=ar[0], fontname=axtickfontname)
      # ylm must be tuple of start, end, interval
      if ylm:
          plt.ylim(bottom=ylm[0], top=ylm[1])
          plt.yticks(np.arange(ylm[0], ylm[1], ylm[2]), fontsize=axtickfontsize, fontname=axtickfontname)
      plt.yticks(fontsize=axtickfontsize, rotation=ar[1], fontname=axtickfontname)
      if axxlabel:
          _x = axxlabel
      if axylabel:
          _y = axylabel
      general.axis_labels(_x, _y, axlabelfontsize, axlabelfontname)
      general.get_figure(show, r, figtype, figname, theme)
  @staticmethod
  def boxplot_single_factor(df='dataframe', column_names=None, grid=False, ar=(0, 0), axtickfontsize=9,
                            axtickfontname='Arial', dim=(6, 4), show=False, figtype='png', figname='boxplot', r=300,
                            ylm=None, box_line_style='-', box_line_width=1, box_line_color='b', med_line_style='-',
                            med_line_width=1, med_line_color='g', whisk_line_color='b', cap_color='b',
                            add_sign_symbol=False, symb_dist=None, sign_symbol_opts={'symbol': '*', 'fontsize': 8 },
                            pv=None, notch=False, outliers=True, fill_box_color=True, dotplot=False, dotsize=6,
                            colordot=['#101820ff'], valphadot=1, markerdot='o', theme=None, dotspread=0.4):
      """Draw one box per column of ``df`` via ``DataFrame.boxplot``, with optional dots and p-value symbols.

      Parameters
      ----------
      df : pandas.DataFrame
          Presumably one column per group with replicate values in rows. Confirm against the caller.
      column_names : str, list or tuple, optional
          Subset of columns to plot. All columns are plotted when omitted.
      ar : tuple
          ``(x_tick_rotation, y_tick_rotation)``.
      ylm : tuple or None
          ``(start, end, interval)`` for the y axis.
      dotplot : bool
          Overlay each column's non-missing values on its box.
      dotspread : float
          Horizontal width over which those dots are spread (new, default 0.4).
      add_sign_symbol : bool
          With ``pv`` as a dict ``{column: p_value}``, draw a symbol above each column with
          ``p <= 0.05``.
      symb_dist : dict or None
          ``{column: extra_offset}``. Missing columns get 0. The caller's dict is not modified.
      """
      if theme == 'dark':
          general.dark_bg()
      plt.subplots()
      # a single column name is allowed; make it a list so it is not iterated character by character
      if isinstance(column_names, str):
          column_names = [column_names]
      # the plotted columns; matplotlib places the boxes at x = 1 .. len(xbar)
      xbar = list(column_names) if column_names else list(df.columns)
      # rot is x axis rotation
      other_args = {'grid': grid, 'rot': ar[0], 'fontsize': axtickfontsize, 'notch': notch, 'showfliers': outliers,
                    'figsize': dim, 'patch_artist': fill_box_color}
      color_args = {'medians': med_line_color, 'boxes': box_line_color, 'whiskers': whisk_line_color,
                    'caps': cap_color}
      medianprops_args = {'linestyle': med_line_style, 'linewidth': med_line_width}
      boxprops_args = {'linestyle': box_line_style, 'linewidth': box_line_width}
      if column_names:
          df.boxplot(column=xbar, **other_args, boxprops=boxprops_args, medianprops=medianprops_args,
                     color=color_args)
      else:
          df.boxplot(**other_args, boxprops=boxprops_args, color=color_args, medianprops=medianprops_args)
      # ylm must be tuple of start, end, interval
      if ylm:
          plt.ylim(bottom=ylm[0], top=ylm[1])
          plt.yticks(np.arange(ylm[0], ylm[1], ylm[2]), fontsize=axtickfontsize, fontname=axtickfontname)
      plt.yticks(fontsize=axtickfontsize, rotation=ar[1], fontname=axtickfontname)
      color_list_dot = colordot * len(xbar) if len(colordot) == 1 else colordot
      if dotplot:
          for pos, col in enumerate(xbar):
              values = df[col].dropna()
              centre = pos + 1
              plt.scatter(x=np.linspace(centre - dotspread / 2, centre + dotspread / 2, len(values)), y=values,
                          s=dotsize, color=color_list_dot[pos], zorder=10, alpha=valphadot, marker=markerdot)
      if add_sign_symbol:
          size_factor_to_start_line = max(df.max()) / 20
          # p and symb_dist should be dict
          if isinstance(pv, dict):
              for k, v in pv.items():
                  y_pos = df[k].max() + size_factor_to_start_line
                  if isinstance(symb_dist, dict):
                      # read-only lookup: never write defaults back into the caller's dict
                      y_pos += symb_dist.get(k, 0)
                  if y_pos > 0 and v <= 0.05:
                      pv_symb = general.pvalue_symbol(v, sign_symbol_opts['symbol'])
                      if pv_symb:
                          plt.annotate(pv_symb, xy=(xbar.index(k) + 1, y_pos),
                                       fontsize=sign_symbol_opts['fontsize'],
                                       ha="center")
      general.get_figure(show, r, figtype, figname, theme)
  @staticmethod
  def roc(fpr=None, tpr=None, c_line_style='-', c_line_color='#f05f21', c_line_width=1, diag_line=True,
          diag_line_style='--', diag_line_width=1, diag_line_color='b', auc=None, shade_auc=False,
          shade_auc_color='#f48d60',
          axxlabel='False Positive Rate (1 - Specificity)', axylabel='True Positive Rate (Sensitivity)', ar=(0, 0),
          axtickfontsize=9, axtickfontname='Arial', axlabelfontsize=9, axlabelfontname='Arial',
          plotlegend=True, legendpos='lower right', legendanchor=None, legendcols=1, legendfontsize=8,
          legendlabelframe=False, legend_columnspacing=None, per_class=False, dim=(6, 5), show=False, figtype='png',
          figname='roc', r=300, ylm=None, theme=None):
      """Plot a ROC curve from false/true positive rates.

      Parameters
      ----------
      fpr, tpr : sequences
          Points of the curve. ``fpr`` is the x axis, ``tpr`` the y axis.
      auc : float or None
          When given (including 0), shown in the legend as ``AUC = x.xxxx``.
      diag_line : bool
          Draw the (0, 0)-(1, 1) "Chance level" reference line.
      per_class : bool
          Also draw the "Perfect performance" reference lines.
      shade_auc : bool
          Fill the area under the curve.
      axxlabel, axylabel : str or None
          Axis labels. Passing None no longer raises UnboundLocalError.
      ylm : tuple or None
          ``(start, end, interval)`` for the y axis.
      """
      # set axis labels to None
      _x = None
      _y = None
      if theme == 'dark':
          general.dark_bg()
      plt.subplots(figsize=dim)
      # plt.margins(x=0)
      # an AUC of exactly 0 is a valid value and still gets a legend label
      if auc is not None:
          plt.plot(fpr, tpr, color=c_line_color, linestyle=c_line_style, linewidth=c_line_width,
                   label='AUC = %0.4f' % auc)
      else:
          plt.plot(fpr, tpr, color=c_line_color, linestyle=c_line_style, linewidth=c_line_width)
      if diag_line:
          plt.plot([0, 1], [0, 1], color=diag_line_color, linestyle=diag_line_style, linewidth=diag_line_width,
                   label='Chance level')
      if per_class:
          plt.plot([0, 0], [0, 1], color='grey', linestyle='-', linewidth=1)
          plt.plot([0, 1], [1, 1], color='grey', linestyle='-', linewidth=1, label='Perfect performance')
      # ylm must be tuple of start, end, interval
      if ylm:
          plt.ylim(bottom=ylm[0], top=ylm[1])
          plt.yticks(np.arange(ylm[0], ylm[1], ylm[2]), fontsize=axtickfontsize, fontname=axtickfontname)
      plt.yticks(fontsize=axtickfontsize, rotation=ar[1], fontname=axtickfontname)
      if axxlabel:
          _x = axxlabel
      if axylabel:
          _y = axylabel
      if shade_auc:
          plt.fill_between(x=fpr, y1=tpr, color=shade_auc_color)
      if plotlegend:
          plt.legend(loc=legendpos, bbox_to_anchor=legendanchor, ncol=legendcols, fontsize=legendfontsize,
                     frameon=legendlabelframe, columnspacing=legend_columnspacing)
      general.axis_labels(_x, _y, axlabelfontsize, axlabelfontname)
      general.get_figure(show, r, figtype, figname, theme)
  1656. class cluster:
  1657. def __init__(self):
  1658. pass
  @staticmethod
  def screeplot(obj="pcascree", axlabelfontsize=9, axlabelfontname="Arial", axxlabel=None,
                axylabel=None, figtype='png', r=300, show=False, dim=(6, 4), theme=None, figname='screeplot'):
      """Draw a scree plot: one bar per component, with heights shown in percent.

      Parameters
      ----------
      obj : pair of sequences
          ``obj[0]`` gives the bar positions/labels and ``obj[1]`` the matching values. The values
          are multiplied by 100 before plotting. They are presumably explained-variance ratios in
          [0, 1]; confirm against the caller.
      axxlabel, axylabel : str or None
          Override the default labels 'PCs' and 'Proportion of variance (%)'.
      figname : str
          Figure name passed to ``general.get_figure`` together with ``figtype``. Defaults to
          the previously hard-coded 'screeplot', so existing callers are unaffected.
      """
      if theme == 'dark':
          general.dark_bg()
      percent = [value * 100 for value in obj[1]]
      plt.subplots(figsize=dim)
      plt.bar(obj[0], percent)
      xlab = axxlabel if axxlabel else 'PCs'
      ylab = axylabel if axylabel else 'Proportion of variance (%)'
      plt.xticks(fontsize=7, rotation=70)
      general.axis_labels(xlab, ylab, axlabelfontsize, axlabelfontname)
      general.get_figure(show, r, figtype, figname, theme)
  @staticmethod
  def pcaplot(x=None, y=None, z=None, labels=None, var1=None, var2=None, var3=None, axlabelfontsize=9,
              axlabelfontname="Arial", figtype='png', r=300, show=False, plotlabels=True, dim=(6, 4), theme=None):
      """Scatter PCA scores in 2D (``z`` is None) or 3D (``z`` given), one point per label.

      Parameters
      ----------
      x, y, z : sequences
          Scores on PC1, PC2 and (3D only) PC3, indexed in the same order as ``labels``.
      labels : sequence
          Name of each point. Drawn next to it when ``plotlabels`` is True.
      var1, var2, var3 : number
          Values formatted into the axis labels as 'PCn (value%)'. Zero is a valid value in both
          2D and 3D mode; previously the 3D check rejected it through truthiness.

      Raises
      ------
      AssertionError
          If a variance value or ``labels`` required by the chosen mode is missing.
      """
      if theme == 'dark':
          general.dark_bg()
      if x is not None and y is not None and z is None:
          assert var1 is not None and var2 is not None and labels is not None, "var1 or var2 variable or labels are missing"
          plt.subplots(figsize=dim)
          for i, varnames in enumerate(labels):
              plt.scatter(x[i], y[i])
              if plotlabels:
                  plt.text(x[i], y[i], varnames, fontsize=10)
          general.axis_labels("PC1 ({}%)".format(var1), "PC2 ({}%)".format(var2), axlabelfontsize, axlabelfontname)
          general.get_figure(show, r, figtype, 'pcaplot_2d', theme)
      elif x is not None and y is not None and z is not None:
          # same None-based check as the 2D branch, so a 0 % variance is accepted
          assert var1 is not None and var2 is not None and var3 is not None and labels is not None, \
              "var1 or var2 or var3 or labels are missing"
          # for 3d plot
          fig = plt.figure(figsize=dim)
          ax = fig.add_subplot(111, projection='3d')
          for i, varnames in enumerate(labels):
              ax.scatter(x[i], y[i], z[i])
              if plotlabels:
                  ax.text(x[i], y[i], z[i], varnames, fontsize=10)
          ax.set_xlabel("PC1 ({}%)".format(var1), fontsize=axlabelfontsize, fontname=axlabelfontname)
          ax.set_ylabel("PC2 ({}%)".format(var2), fontsize=axlabelfontsize, fontname=axlabelfontname)
          ax.set_zlabel("PC3 ({}%)".format(var3), fontsize=axlabelfontsize, fontname=axlabelfontname)
          general.get_figure(show, r, figtype, 'pcaplot_3d', theme)
  1703. @staticmethod
  1704. # adapted from https://stackoverflow.com/questions/39216897/plot-pca-loadings-and-loading-in-biplot-in-sklearn-like-rs-autoplot
  1705. def biplot(cscore=None, loadings=None, labels=None, var1=None, var2=None, var3=None, axlabelfontsize=9, axlabelfontname="Arial",
  1706. figtype='png', r=300, show=False, markerdot="o", dotsize=6, valphadot=1, colordot='#eba487', arrowcolor='#87ceeb',
  1707. valphaarrow=1, arrowlinestyle='-', arrowlinewidth=0.5, centerlines=True, colorlist=None, legendpos='best',
  1708. datapoints=True, dim=(6, 4), theme=None):
  1709. if theme == 'dark':
  1710. general.dark_bg()
  1711. assert cscore is not None and loadings is not None and labels is not None and var1 is not None and var2 is not None, \
  1712. "cscore or loadings or labels or var1 or var2 are missing"
  1713. if var1 is not None and var2 is not None and var3 is None:
  1714. xscale = 1.0 / (cscore[:, 0].max() - cscore[:, 0].min())
  1715. yscale = 1.0 / (cscore[:, 1].max() - cscore[:, 1].min())
  1716. # zscale = 1.0 / (cscore[:, 2].max() - cscore[:, 2].min())
  1717. # colorlist is an array of classes from dataframe column
  1718. plt.subplots(figsize=dim)
  1719. if datapoints:
  1720. if colorlist is not None:
  1721. unique_class = set(colorlist)
  1722. # color_dict = dict()
  1723. assign_values = {col: i for i, col in enumerate(unique_class)}
  1724. color_result_num = [assign_values[i] for i in colorlist]
  1725. if colordot and isinstance(colordot, (tuple, list)):
  1726. colour_map = ListedColormap(colordot)
  1727. # for i in range(len(list(unique_class))):
  1728. # color_dict[list(unique_class)[i]] = colordot[i]
  1729. # color_result = [color_dict[i] for i in colorlist]
  1730. s = plt.scatter(cscore[:, 0] * xscale, cscore[:, 1] * yscale, c=color_result_num, cmap=colour_map,
  1731. s=dotsize, alpha=valphadot, marker=markerdot)
  1732. plt.legend(handles=s.legend_elements()[0], labels=list(unique_class), loc=legendpos)
  1733. elif colordot and not isinstance(colordot, (tuple, list)):
  1734. # s = plt.scatter(cscore[:, 0] * xscale, cscore[:, 1] * yscale, color=color_result, s=dotsize,
  1735. # alpha=valphadot, marker=markerdot)
  1736. # plt.legend(handles=s.legend_elements()[0], labels=list(unique_class))
  1737. s = plt.scatter(cscore[:, 0] * xscale, cscore[:, 1] * yscale, c=color_result_num, s=dotsize,
  1738. alpha=valphadot, marker=markerdot)
  1739. plt.legend(handles=s.legend_elements()[0], labels=list(unique_class), loc=legendpos)
  1740. else:
  1741. plt.scatter(cscore[:, 0] * xscale, cscore[:, 1] * yscale, color=colordot, s=dotsize,
  1742. alpha=valphadot, marker=markerdot)
  1743. if centerlines:
  1744. plt.axhline(y=0, linestyle='--', color='#7d7d7d', linewidth=1)
  1745. plt.axvline(x=0, linestyle='--', color='#7d7d7d', linewidth=1)
  1746. # loadings[0] is the number of the original variables
  1747. # this is important where variables more than number of observations
  1748. for i in range(len(loadings[0])):
  1749. plt.arrow(0, 0, loadings[0][i], loadings[1][i], color=arrowcolor, alpha=valphaarrow, ls=arrowlinestyle,
  1750. lw=arrowlinewidth)
  1751. plt.text(loadings[0][i], loadings[1][i], labels[i])
  1752. # adjust_text(t)
  1753. # plt.xlim(min(loadings[0]) - 0.1, max(loadings[0]) + 0.1)
  1754. # plt.ylim(min(loadings[1]) - 0.1, max(loadings[1]) + 0.1)
  1755. xlimit_max = np.max([np.max(cscore[:, 0]*xscale), np.max(loadings[0])])
  1756. xlimit_min = np.min([np.min(cscore[:, 0]*xscale), np.min(loadings[0])])
  1757. ylimit_max = np.max([np.max(cscore[:, 1]*yscale), np.max(loadings[1])])
  1758. ylimit_min = np.min([np.min(cscore[:, 1]*yscale), np.min(loadings[1])])
  1759. plt.xlim(xlimit_min-0.2, xlimit_max+0.2)
  1760. plt.ylim(ylimit_min-0.2, ylimit_max+0.2)
  1761. general.axis_labels("PC1 ({}%)".format(var1), "PC2 ({}%)".format(var2), axlabelfontsize, axlabelfontname)
  1762. general.get_figure(show, r, figtype, 'biplot_2d', theme)
  1763. # 3D
  1764. if var1 is not None and var2 is not None and var3 is not None:
  1765. xscale = 1.0 / (cscore[:, 0].max() - cscore[:, 0].min())
  1766. yscale = 1.0 / (cscore[:, 1].max() - cscore[:, 1].min())
  1767. zscale = 1.0 / (cscore[:, 2].max() - cscore[:, 2].min())
  1768. fig = plt.figure(figsize=dim)
  1769. ax = fig.add_subplot(111, projection='3d')
  1770. if datapoints:
  1771. if colorlist is not None:
  1772. unique_class = set(colorlist)
  1773. assign_values = {col: i for i, col in enumerate(unique_class)}
  1774. color_result_num = [assign_values[i] for i in colorlist]
  1775. if colordot and isinstance(colordot, (tuple, list)):
  1776. colour_map = ListedColormap(colordot)
  1777. s = ax.scatter(cscore[:, 0]*xscale, cscore[:, 1]*yscale, cscore[:, 2]*zscale, c=color_result_num,
  1778. cmap=colour_map, s=dotsize, alpha=valphadot, marker=markerdot)
  1779. plt.legend(handles=s.legend_elements()[0], labels=list(unique_class), loc=legendpos)
  1780. elif colordot and not isinstance(colordot, (tuple, list)):
  1781. s = ax.scatter(cscore[:, 0]*xscale, cscore[:, 1]*yscale, cscore[:, 2]*zscale, c=color_result_num,
  1782. s=dotsize, alpha=valphadot, marker=markerdot)
  1783. plt.legend(handles=s.legend_elements()[0], labels=list(unique_class), loc=legendpos)
  1784. else:
  1785. ax.scatter(cscore[:, 0] * xscale, cscore[:, 1] * yscale, cscore[:, 2] * zscale, color=colordot,
  1786. s=dotsize, alpha=valphadot, marker=markerdot)
  1787. for i in range(len(loadings[0])):
  1788. ax.quiver(0, 0, 0, loadings[0][i], loadings[1][i], loadings[2][i], color=arrowcolor, alpha=valphaarrow,
  1789. ls=arrowlinestyle, lw=arrowlinewidth)
  1790. ax.text(loadings[0][i], loadings[1][i], loadings[2][i], labels[i])
  1791. xlimit_max = np.max([np.max(cscore[:, 0] * xscale), np.max(loadings[0])])
  1792. xlimit_min = np.min([np.min(cscore[:, 0] * xscale), np.min(loadings[0])])
  1793. ylimit_max = np.max([np.max(cscore[:, 1] * yscale), np.max(loadings[1])])
  1794. ylimit_min = np.min([np.min(cscore[:, 1] * yscale), np.min(loadings[1])])
  1795. zlimit_max = np.max([np.max(cscore[:, 2] * zscale), np.max(loadings[2])])
  1796. zlimit_min = np.min([np.min(cscore[:, 2] * zscale), np.min(loadings[2])])
  1797. # ax.set_xlim(min(loadings[0])-0.1, max(loadings[0])+0.1)
  1798. # ax.set_ylim(min(loadings[1])-0.1, max(loadings[1])+0.1)
  1799. # ax.set_zlim(min(loadings[2])-0.1, max(loadings[2])+0.1)
  1800. ax.set_xlim(xlimit_min-0.2, xlimit_max+0.2)
  1801. ax.set_ylim(ylimit_min-0.2, ylimit_max+0.2)
  1802. ax.set_zlim(zlimit_min-0.2, zlimit_max+0.2)
  1803. ax.set_xlabel("PC1 ({}%)".format(var1), fontsize=axlabelfontsize, fontname=axlabelfontname)
  1804. ax.set_ylabel("PC2 ({}%)".format(var2), fontsize=axlabelfontsize, fontname=axlabelfontname)
  1805. ax.set_zlabel("PC3 ({}%)".format(var3), fontsize=axlabelfontsize, fontname=axlabelfontname)
  1806. general.get_figure(show, r, figtype, 'biplot_3d', theme)
  1807. def tsneplot(score=None, axlabelfontsize=9, axlabelfontname="Arial", figtype='png', r=300, show=False,
  1808. markerdot="o", dotsize=6, valphadot=1, colordot='#4a4e4d', colorlist=None, legendpos='best',
  1809. figname='tsne_2d', dim=(6, 4), legendanchor=None, theme=None):
  1810. assert score is not None, "score are missing"
  1811. if theme == 'dark':
  1812. general.dark_bg()
  1813. plt.subplots(figsize=dim)
  1814. if colorlist is not None:
  1815. unique_class = set(colorlist)
  1816. # color_dict = dict()
  1817. assign_values = {col: i for i, col in enumerate(unique_class)}
  1818. color_result_num = [assign_values[i] for i in colorlist]
  1819. if colordot and isinstance(colordot, (tuple, list)):
  1820. colour_map = ListedColormap(colordot)
  1821. s = plt.scatter(score[:, 0], score[:, 1], c=color_result_num, cmap=colour_map,
  1822. s=dotsize, alpha=valphadot, marker=markerdot)
  1823. plt.legend(handles=s.legend_elements()[0], labels=list(unique_class), loc=legendpos,
  1824. bbox_to_anchor=legendanchor)
  1825. elif colordot and not isinstance(colordot, (tuple, list)):
  1826. s = plt.scatter(score[:, 0], score[:, 1], c=color_result_num,
  1827. s=dotsize, alpha=valphadot, marker=markerdot)
  1828. plt.legend(handles=s.legend_elements()[0], labels=list(unique_class), loc=legendpos,
  1829. bbox_to_anchor=legendanchor)
  1830. else:
  1831. plt.scatter(score[:, 0], score[:, 1], color=colordot,
  1832. s=dotsize, alpha=valphadot, marker=markerdot)
  1833. plt.xlabel("t-SNE-1", fontsize=axlabelfontsize, fontname=axlabelfontname)
  1834. plt.ylabel("t-SNE-2", fontsize=axlabelfontsize, fontname=axlabelfontname)
  1835. general.get_figure(show, r, figtype, figname, theme)