def scatterplot(
    df: pd.DataFrame,
    col: str,
    color: str = None,
    hover_name: str = None,
    hover_data: list = None,
    title="",
    return_figure=False,
):
    """Draw a 2-D or 3-D scatter plot of the vectors stored in ``df[col]``.

    Each cell of ``df[col]`` is expected to hold a sequence of 2 or 3 numbers
    (e.g. a list or a 1-D NumPy array). The first, second and (if present)
    third components of every vector become the x, y and z coordinates of one
    point. 2-D data is drawn with ``px.scatter``, 3-D data with
    ``px.scatter_3d``.

    Args:
        df: DataFrame holding the data. It is also passed to Plotly Express as
            the data frame, so ``color``, ``hover_name`` and ``hover_data`` may
            refer to columns of ``df``.
        col: Name of the column containing the vectors to plot.
        color: Forwarded unchanged as Plotly Express ``color``.
        hover_name: Forwarded unchanged as Plotly Express ``hover_name``.
        hover_data: Forwarded unchanged as Plotly Express ``hover_data``.
        title: Figure title.
        return_figure: If true, return the figure instead of showing it.

    Returns:
        The figure when ``return_figure`` is true; otherwise ``None`` after
        displaying the figure with ``fig.show()``.

    Raises:
        ValueError: If ``df[col]`` is empty, if its vectors have differing
            lengths (raised by ``np.stack``), or if the cells are not
            2- or 3-component vectors.
    """
    vectors = list(df[col])
    if not vectors:
        raise ValueError(f"Column {col!r} is empty; there is nothing to plot.")

    # One row per data point: shape (n_points, dimension).
    points = np.stack(vectors, axis=0)

    # Check ndim as well as the width so that scalar cells (ndim == 1) and
    # matrix-valued cells (ndim > 2) hit this clear error instead of a NumPy
    # AxisError or an obscure failure inside Plotly.
    dimension = points.shape[1] if points.ndim == 2 else None
    if dimension not in (2, 3):
        raise ValueError(
            "The column you want to visualize has dimension < 2 or dimension > 3."
            " The function can only visualize 2- and 3-dimensional data."
        )

    # Arguments shared by the 2-D and 3-D Plotly Express calls.
    plot_kwargs = dict(
        color=color,
        hover_data=hover_data,
        title=title,
        hover_name=hover_name,
    )
    if dimension == 2:
        fig = px.scatter(df, x=points[:, 0], y=points[:, 1], **plot_kwargs)
    else:
        fig = px.scatter_3d(
            df, x=points[:, 0], y=points[:, 1], z=points[:, 2], **plot_kwargs
        )

    if return_figure:
        return fig
    fig.show()
    return None