graph_cat.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import matplotlib.ticker as mtick
  4. from sweetviz import sv_math
  5. from sweetviz import utils
  6. from sweetviz.config import config
  7. from sweetviz.sv_types import FeatureType, FeatureToProcess, OTHERS_GROUPED
  8. import sweetviz.graph
  9. from typing import List
  10. def plot_grouped_bars(tick_names: List[str], data_lists: List[List], \
  11. colors: List[str], gap_percent: float, axis_obj = None, \
  12. orientation: str = 'vertical', **kwargs):
  13. if len(data_lists) > len(colors):
  14. raise ValueError
  15. num_data_lists = len(data_lists)
  16. locations_centered = np.arange(len(tick_names))
  17. usable_for_bars = 1.0 - (gap_percent / 100.0)
  18. bar_width = usable_for_bars / num_data_lists
  19. center_offset = (bar_width / 2.0) * (1 - num_data_lists % 2)
  20. tick_positions = locations_centered + usable_for_bars / 2.0
  21. category_starts = locations_centered + center_offset
  22. offset = 0.0
  23. for cur_height_list, cur_color in zip(data_lists, colors):
  24. if len(tick_names) != len(cur_height_list):
  25. raise ValueError
  26. if axis_obj:
  27. # AXIS object is already provided, use it
  28. if orientation == 'vertical':
  29. plt.xticks(locations_centered, tick_names)
  30. axis_obj.bar(category_starts + offset, cur_height_list, \
  31. bar_width, color=cur_color, **kwargs)
  32. else:
  33. plt.yticks(locations_centered, tick_names)
  34. axis_obj.barh(category_starts + offset, cur_height_list, \
  35. bar_width, color=cur_color, **kwargs)
  36. else:
  37. # AXIS object is not provided, use "plt."
  38. if orientation == 'vertical':
  39. plt.xticks(locations_centered, tick_names)
  40. plt.bar(category_starts + offset, cur_height_list, bar_width, \
  41. color=cur_color, **kwargs)
  42. else:
  43. plt.yticks(locations_centered, tick_names)
  44. plt.barh(category_starts + offset, cur_height_list, bar_width, \
  45. color=cur_color, **kwargs)
  46. offset = offset - bar_width
  47. # return category_starts + (bar_width / 2.0), bar_width
  48. return locations_centered, bar_width
  49. class GraphCat(sweetviz.graph.Graph):
  50. def __init__(self, which_graph: str, to_process: FeatureToProcess):
  51. if to_process.is_target() and which_graph == "mini":
  52. styles = ["graph_base.mplstyle", "graph_target.mplstyle"]
  53. else:
  54. styles = ["graph_base.mplstyle"]
  55. self.set_style(styles)
  56. is_detail = which_graph.find("detail") != -1
  57. cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
  58. if which_graph == "mini":
  59. max_categories = config["Graphs"].getint("summary_graph_max_categories")
  60. elif is_detail:
  61. max_categories = config["Graphs"].getint("detail_graph_max_categories")
  62. else:
  63. raise ValueError
  64. plot_data_series = utils.get_clamped_value_counts( \
  65. to_process.source_counts["value_counts_without_nan"], max_categories)
  66. if which_graph == "mini":
  67. f, axs = plt.subplots(1, 1, \
  68. figsize=(config["Graphs"].getfloat("cat_summary_graph_width"),
  69. config["Graphs"].getfloat("summary_graph_height")))
  70. gap_percent = config["Graphs"].getfloat("summary_graph_categorical_gap")
  71. axs.tick_params(axis='x', direction='out', pad=0, labelsize=8, length=2)
  72. axs.tick_params(axis='y', direction='out', pad=2, labelsize=8, length=2)
  73. axs.xaxis.tick_top()
  74. elif is_detail:
  75. height = config["Graphs"].getfloat("detail_graph_height_base") \
  76. + config["Graphs"].getfloat("detail_graph_height_per_elem") * max(1, len(plot_data_series))
  77. if height > config["Graphs"].getfloat("detail_graph_categorical_max_height"):
  78. # Shrink height to fit, past a certain number
  79. height = config["Graphs"].getfloat("detail_graph_categorical_max_height")
  80. f, axs = plt.subplots(1, 1, \
  81. figsize=(config["Graphs"].getfloat("detail_graph_width"), height))
  82. gap_percent = config["Graphs"].getfloat("detail_graph_categorical_gap")
  83. axs.tick_params(axis='x', direction='out', pad=0, labelsize=8, length=2)
  84. axs.tick_params(axis='y', direction='out', pad=2, labelsize=8, length=2)
  85. axs.xaxis.tick_top()
  86. self.size_in_inches = f.get_size_inches()
  87. tick_names = list(plot_data_series.index)
  88. # To show percentages
  89. sum_source = sum(plot_data_series)
  90. plot_data_series = plot_data_series / sum_source if sum_source != 0.0 else plot_data_series * 0.0
  91. axs.xaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0, decimals=0))
  92. # MAIN DATA (renders "under" target plots)
  93. # -----------------------------------------------------------
  94. if to_process.compare is not None:
  95. # COMPARE
  96. matched_data_series = utils.get_matched_value_counts( \
  97. to_process.compare_counts["value_counts_without_nan"],plot_data_series)
  98. # Show percentages
  99. sum_compared = sum(matched_data_series)
  100. matched_data_series = matched_data_series / sum_compared if sum_compared != 0.0 else \
  101. matched_data_series * 0.0
  102. height_lists = [list(plot_data_series.values), list(matched_data_series)]
  103. else:
  104. height_lists = [list(plot_data_series.values)]
  105. # Reorder so it plots with max values on top, "Others" at bottom
  106. # Plot: index 0 at BOTTOM
  107. # Need to change TICK NAMES and all elements in height_lists
  108. # ---------------------------------------------
  109. reversed_height_lists = list()
  110. for height_list in height_lists:
  111. reversed_height_lists.append(list(reversed(height_list)))
  112. tick_names = list(reversed(tick_names))
  113. height_lists = reversed_height_lists
  114. try:
  115. others_index = tick_names.index(OTHERS_GROUPED)
  116. tick_names.insert(0, tick_names.pop(others_index))
  117. for height_list in height_lists:
  118. height_list.insert(0, height_list.pop(others_index))
  119. except:
  120. pass
  121. # Escape LaTeX
  122. tick_names_for_labels_only = tick_names
  123. if len(tick_names):
  124. if type(tick_names[0]) == str:
  125. tick_names_for_labels_only = [str(x).replace("$",r"\$") for x in tick_names]
  126. # colors = ("r", "b")
  127. category_centers, bar_width = \
  128. plot_grouped_bars(tick_names_for_labels_only, height_lists, cycle_colors, gap_percent,
  129. orientation = 'horizontal', axis_obj = axs)
  130. # TARGET
  131. # -----------------------------------------------------------
  132. if to_process.source_target is not None:
  133. if to_process.predetermined_type_target == FeatureType.TYPE_NUM:
  134. # TARGET: IS NUMERIC
  135. target_values_source = list()
  136. names_excluding_others = [key for key in tick_names if key != OTHERS_GROUPED]
  137. for name in tick_names:
  138. if name == OTHERS_GROUPED:
  139. tick_average = to_process.source_target[ \
  140. ~to_process.source.isin(names_excluding_others)].mean()
  141. else:
  142. tick_average = to_process.source_target[ \
  143. to_process.source == name].mean()
  144. target_values_source.append(tick_average)
  145. ax2 = axs.twiny()
  146. ax2.xaxis.set_major_formatter(mtick.FuncFormatter(self.format_smart))
  147. ax2.xaxis.tick_bottom()
  148. # Need to redo this for some reason after twinning:
  149. axs.xaxis.tick_top()
  150. ax2.tick_params(axis='x', direction='out', pad=2, labelsize=8, length=2)
  151. ax2.plot(target_values_source, category_centers,
  152. marker='o', color=sweetviz.graph.COLOR_TARGET_SOURCE)
  153. if to_process.compare is not None and \
  154. to_process.compare_target is not None:
  155. # TARGET NUMERIC: with compare TARGET
  156. target_values_compare = list()
  157. for name in tick_names:
  158. if name == OTHERS_GROUPED:
  159. tick_average = to_process.compare_target[ \
  160. ~to_process.compare.isin(names_excluding_others)].mean()
  161. else:
  162. tick_average = to_process.compare_target[ \
  163. to_process.compare == name].mean()
  164. target_values_compare.append(tick_average)
  165. ax2.plot(target_values_compare,
  166. category_centers, marker='o', color=sweetviz.graph.COLOR_TARGET_COMPARE)
  167. elif to_process.predetermined_type_target == FeatureType.TYPE_BOOL:
  168. # TARGET: IS BOOL
  169. # ------------------------------------
  170. target_values_source = list()
  171. names_excluding_others = [key for key in tick_names if key != OTHERS_GROUPED]
  172. for name in tick_names:
  173. if name == OTHERS_GROUPED:
  174. tick_num = sv_math.count_fraction_of_true(to_process.source_target[ \
  175. ~to_process.source.isin(names_excluding_others)])[0]
  176. else:
  177. tick_num = sv_math.count_fraction_of_true(to_process.source_target[ \
  178. to_process.source == name])[0]
  179. target_values_source.append(tick_num)
  180. # target_values_source.append(tick_num * plot_data_series[name])
  181. # ax2 = axs.twiny()
  182. # ax2.xaxis.set_major_formatter(mtick.FuncFormatter(self.format_smart))
  183. # ax2.xaxis.tick_bottom()
  184. # # Need to redo this for some reason after twinning:
  185. # axs.xaxis.tick_top()
  186. # ax2.tick_params(axis='x', direction='out', pad=2, labelsize=8, length=2)
  187. axs.plot(target_values_source, category_centers,
  188. marker='o', color=sweetviz.graph.COLOR_TARGET_SOURCE)
  189. target_values_compare = list()
  190. if to_process.compare is not None and \
  191. to_process.compare_target is not None:
  192. # TARGET BOOL: with compare TARGET
  193. for name in tick_names:
  194. if name == OTHERS_GROUPED:
  195. tick_num = sv_math.count_fraction_of_true(to_process.compare_target[ \
  196. ~to_process.compare.isin(names_excluding_others)])[0]
  197. else:
  198. tick_num = sv_math.count_fraction_of_true(to_process.compare_target[ \
  199. to_process.compare == name])[0]
  200. target_values_compare.append(tick_num)
  201. # target_values_compare.append(tick_num * matched_data_series[name])
  202. axs.plot(target_values_compare, category_centers,
  203. marker='o', color=sweetviz.graph.COLOR_TARGET_COMPARE)
  204. # else:
  205. # # TARGET BOOL: NO compare TARGET -> Just fill with zeros so alignment is still good
  206. # for name in tick_names:
  207. # target_values_compare.append(0.0)
  208. # target_plot_series = [target_values_source, target_values_compare]
  209. # plot_grouped_bars(tick_names, target_plot_series, ('k','k'), gap_percent,
  210. # orientation='horizontal', axis_obj=axs, alpha=0.6)
  211. # Finalize Graph
  212. # -----------------------------
  213. # Needs only ~5 on right, but want to match num
  214. if which_graph == "mini":
  215. needed_pixels_padding = np.array([14.0, (300 + 32), 14, 45]) # TOP-LEFT-BOTTOM-RIGHT
  216. else:
  217. needed_pixels_padding = np.array([14.0, 140, 16, 45]) # TOP-LEFT-BOTTOM-RIGHT
  218. padding_fraction = needed_pixels_padding
  219. padding_fraction[0] = padding_fraction[0] / (self.size_in_inches[1] * f.dpi)
  220. padding_fraction[2] = padding_fraction[2] / (self.size_in_inches[1] * f.dpi)
  221. padding_fraction[3] = padding_fraction[3] / (self.size_in_inches[0] * f.dpi)
  222. padding_fraction[1] = padding_fraction[1] / (self.size_in_inches[0] * f.dpi)
  223. plt.subplots_adjust(top=(1.0 - padding_fraction[0]), left=padding_fraction[1], \
  224. bottom=padding_fraction[2], right=(1.0 - padding_fraction[3]))
  225. self.graph_base64 = self.get_encoded_base64(f)
  226. plt.close('all')