visualize_3_1.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. def scatterplot(
  2. df: pd.DataFrame,
  3. col: str,
  4. color: str = None,
  5. hover_name: str = None,
  6. hover_data: list = None,
  7. title="",
  8. return_figure=False,
  9. ):
  10. """
  11. Show scatterplot of DataFrame column using python plotly scatter.
  12. Plot the values in column col. For example, if every cell in df[col]
  13. is a list of three values (e.g. from doing PCA with 3 components),
  14. a 3D-Plot is created and every cell entry [x, y, z] is visualized
  15. as the point (x, y, z).
  16. Parameters
  17. ----------
  18. df: DataFrame with a column to be visualized.
  19. col: str
  20. The name of the column of the DataFrame to use for x and y (and z)
  21. axis.
  22. color: str, optional, default=None
  23. Name of the column to use for coloring (rows with same value get same
  24. color).
  25. hover_name: str, optional, default=None
  26. Name of the column to supply title of hover data when hovering over a
  27. point.
  28. hover_data: List[str], optional, default=[]
  29. List of column names to supply data when hovering over a point.
  30. title: str, default to "".
  31. Title of the plot.
  32. return_figure: bool, optional, default=False
  33. Function returns the figure instead of showing it if set to True.
  34. Examples
  35. --------
  36. >>> import texthero as hero
  37. >>> import pandas as pd
  38. >>> df = pd.DataFrame(["Football, Sports, Soccer",
  39. ... "music, violin, orchestra", "football, fun, sports",
  40. ... "music, fun, guitar"], columns=["texts"])
  41. >>> df["texts"] = hero.clean(df["texts"]).pipe(hero.tokenize)
  42. >>> df["pca"] = (
  43. ... hero.tfidf(df["texts"])
  44. ... .pipe(hero.pca, n_components=3)
  45. ... )
  46. >>> df["topics"] = (
  47. ... hero.tfidf(df["texts"])
  48. ... .pipe(hero.kmeans, n_clusters=2)
  49. ... )
  50. >>> hero.scatterplot(df, col="pca", color="topics",
  51. ... hover_data=["texts"]) # doctest: +SKIP
  52. """
  53. plot_values = np.stack(df[col], axis=1)
  54. dimension = len(plot_values)
  55. if dimension < 2 or dimension > 3:
  56. raise ValueError(
  57. "The column you want to visualize has dimension < 2 or dimension > 3."
  58. " The function can only visualize 2- and 3-dimensional data."
  59. )
  60. if dimension == 2:
  61. x, y = plot_values[0], plot_values[1]
  62. fig = px.scatter(
  63. df,
  64. x=x,
  65. y=y,
  66. color=color,
  67. hover_data=hover_data,
  68. title=title,
  69. hover_name=hover_name,
  70. )
  71. else:
  72. x, y, z = plot_values[0], plot_values[1], plot_values[2]
  73. fig = px.scatter_3d(
  74. df,
  75. x=x,
  76. y=y,
  77. z=z,
  78. color=color,
  79. hover_data=hover_data,
  80. title=title,
  81. hover_name=hover_name,
  82. )
  83. if return_figure:
  84. return fig
  85. else:
  86. fig.show()