123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
def scatterplot(
    df: pd.DataFrame,
    col: str,
    color: str = None,
    hover_name: str = None,
    hover_data: list = None,
    title="",
    return_figure=False,
):
    """
    Show scatterplot of DataFrame column using python plotly scatter.

    Plot the values in column col. For example, if every cell in df[col]
    is a list of three values (e.g. from doing PCA with 3 components),
    a 3D-Plot is created and every cell entry [x, y, z] is visualized
    as the point (x, y, z).

    Parameters
    ----------
    df: DataFrame with a column to be visualized.

    col: str
        The name of the column of the DataFrame to use for x and y (and z)
        axis.

    color: str, optional, default=None
        Name of the column to use for coloring (rows with same value get same
        color).

    hover_name: str, optional, default=None
        Name of the column to supply title of hover data when hovering over a
        point.

    hover_data: List[str], optional, default=None
        List of column names to supply data when hovering over a point.

    title: str, default to "".
        Title of the plot.

    return_figure: bool, optional, default=False
        Function returns the figure instead of showing it if set to True.

    Returns
    -------
    The plotly figure if return_figure is True. Otherwise the figure is
    displayed with ``fig.show()`` and nothing is returned.

    Raises
    ------
    ValueError
        If the column is empty, or if its cells do not hold exactly 2 or 3
        values each.

    Examples
    --------
    >>> import texthero as hero
    >>> import pandas as pd
    >>> df = pd.DataFrame(["Football, Sports, Soccer",
    ...                    "music, violin, orchestra", "football, fun, sports",
    ...                    "music, fun, guitar"], columns=["texts"])
    >>> df["texts"] = hero.clean(df["texts"]).pipe(hero.tokenize)
    >>> df["pca"] = (
    ...     hero.tfidf(df["texts"])
    ...         .pipe(hero.pca, n_components=3)
    ... )
    >>> df["topics"] = (
    ...     hero.tfidf(df["texts"])
    ...         .pipe(hero.kmeans, n_clusters=2)
    ... )
    >>> hero.scatterplot(df, col="pca", color="topics",
    ...                  hover_data=["texts"]) # doctest: +SKIP
    """
    column = df[col]

    # np.stack cannot stack zero arrays and would fail with an unrelated
    # message, so reject an empty column up front with a clear error.
    if len(column) == 0:
        raise ValueError(
            "The column you want to visualize is empty."
            " The function needs at least one row to plot."
        )

    # Stacking the cells along axis 1 yields one row per coordinate axis:
    # plot_values[0] holds every x value, plot_values[1] every y value, ...
    plot_values = np.stack(column, axis=1)
    dimension = len(plot_values)

    if dimension < 2 or dimension > 3:
        raise ValueError(
            "The column you want to visualize has dimension < 2 or dimension > 3."
            " The function can only visualize 2- and 3-dimensional data."
        )

    # Keyword arguments shared by the 2D and the 3D plotting call.
    shared_kwargs = dict(
        color=color,
        hover_data=hover_data,
        title=title,
        hover_name=hover_name,
    )

    if dimension == 2:
        fig = px.scatter(df, x=plot_values[0], y=plot_values[1], **shared_kwargs)
    else:
        fig = px.scatter_3d(
            df,
            x=plot_values[0],
            y=plot_values[1],
            z=plot_values[2],
            **shared_kwargs,
        )

    if return_figure:
        return fig

    fig.show()
|