1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950 |
def scatterplot(
    df: pd.DataFrame,
    col: str,
    color: str = None,
    hover_name: str = None,
    hover_data: list = None,
    title: str = "",
    return_figure: bool = False,
):
    """
    Show (or return) a 2D or 3D scatter plot of the vectors stored in ``df[col]``.

    Every entry of ``df[col]`` is treated as one point; its components become
    the x, y (and, for 3-element vectors, z) coordinates. Plotting is delegated
    to ``px.scatter`` / ``px.scatter_3d`` (``px`` is presumably
    ``plotly.express`` imported at module level — confirm).

    Parameters
    ----------
    df : pd.DataFrame
        Frame holding the data; it is also passed to the plotting call so that
        ``color``, ``hover_name`` and ``hover_data`` can refer to its columns.
    col : str
        Column whose entries are equal-length flat vectors of length 2 or 3.
    color : str, optional
        Forwarded unchanged as the plotting call's ``color`` argument.
    hover_name : str, optional
        Forwarded unchanged as the plotting call's ``hover_name`` argument.
    hover_data : list, optional
        Forwarded unchanged as the plotting call's ``hover_data`` argument.
    title : str, default ""
        Figure title.
    return_figure : bool, default False
        If True, return the figure instead of displaying it.

    Returns
    -------
    The figure when ``return_figure`` is True; otherwise ``None`` (the figure
    is displayed via ``fig.show()``).

    Raises
    ------
    ValueError
        If ``df[col]`` is empty, if its entries are not flat vectors, if the
        vectors have a length other than 2 or 3, or (raised by ``np.stack``)
        if the vectors do not all have the same shape.
    """
    values = df[col]
    # Without this guard np.stack fails with an opaque
    # "need at least one array to stack"; report the real problem instead.
    if values.empty:
        raise ValueError(
            f"The column '{col}' is empty; there are no points to plot."
        )

    # Stacking along axis=1 yields shape (dimension, n_rows): row i holds the
    # i-th coordinate of every point.
    plot_values = np.stack(values, axis=1)
    if plot_values.ndim != 2:
        # Nested entries (e.g. matrices) would otherwise hand 2-D coordinate
        # arrays to the plotting call.
        raise ValueError(
            f"Every entry of the column '{col}' must be a flat (1-D) vector."
        )

    dimension = len(plot_values)
    if dimension not in (2, 3):
        raise ValueError(
            "The column you want to visualize has dimension < 2 or dimension > 3."
            " The function can only visualize 2- and 3-dimensional data."
        )

    # Arguments shared by the 2D and 3D plots.
    shared_kwargs = dict(
        color=color,
        hover_data=hover_data,
        title=title,
        hover_name=hover_name,
    )
    if dimension == 2:
        fig = px.scatter(
            df, x=plot_values[0], y=plot_values[1], **shared_kwargs
        )
    else:
        fig = px.scatter_3d(
            df,
            x=plot_values[0],
            y=plot_values[1],
            z=plot_values[2],
            **shared_kwargs,
        )

    if return_figure:
        return fig
    fig.show()
|