graph_cat_1.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. def plot_grouped_bars(tick_names: List[str], data_lists: List[List], \
  2. colors: List[str], gap_percent: float, axis_obj = None, \
  3. orientation: str = 'vertical', **kwargs):
  4. if len(data_lists) > len(colors):
  5. raise ValueError
  6. num_data_lists = len(data_lists)
  7. locations_centered = np.arange(len(tick_names))
  8. usable_for_bars = 1.0 - (gap_percent / 100.0)
  9. bar_width = usable_for_bars / num_data_lists
  10. center_offset = (bar_width / 2.0) * (1 - num_data_lists % 2)
  11. tick_positions = locations_centered + usable_for_bars / 2.0
  12. category_starts = locations_centered + center_offset
  13. offset = 0.0
  14. for cur_height_list, cur_color in zip(data_lists, colors):
  15. if len(tick_names) != len(cur_height_list):
  16. raise ValueError
  17. if axis_obj:
  18. # AXIS object is already provided, use it
  19. if orientation == 'vertical':
  20. plt.xticks(locations_centered, tick_names)
  21. axis_obj.bar(category_starts + offset, cur_height_list, \
  22. bar_width, color=cur_color, **kwargs)
  23. else:
  24. plt.yticks(locations_centered, tick_names)
  25. axis_obj.barh(category_starts + offset, cur_height_list, \
  26. bar_width, color=cur_color, **kwargs)
  27. else:
  28. # AXIS object is not provided, use "plt."
  29. if orientation == 'vertical':
  30. plt.xticks(locations_centered, tick_names)
  31. plt.bar(category_starts + offset, cur_height_list, bar_width, \
  32. color=cur_color, **kwargs)
  33. else:
  34. plt.yticks(locations_centered, tick_names)
  35. plt.barh(category_starts + offset, cur_height_list, bar_width, \
  36. color=cur_color, **kwargs)
  37. offset = offset - bar_width
  38. # return category_starts + (bar_width / 2.0), bar_width
  39. return locations_centered, bar_width