# visualize_3_1.py (1.1 KB)

  1. def scatterplot(
  2. df: pd.DataFrame,
  3. col: str,
  4. color: str = None,
  5. hover_name: str = None,
  6. hover_data: list = None,
  7. title="",
  8. return_figure=False,
  9. ):
  10. plot_values = np.stack(df[col], axis=1)
  11. dimension = len(plot_values)
  12. if dimension < 2 or dimension > 3:
  13. raise ValueError(
  14. "The column you want to visualize has dimension < 2 or dimension > 3."
  15. " The function can only visualize 2- and 3-dimensional data."
  16. )
  17. if dimension == 2:
  18. x, y = plot_values[0], plot_values[1]
  19. fig = px.scatter(
  20. df,
  21. x=x,
  22. y=y,
  23. color=color,
  24. hover_data=hover_data,
  25. title=title,
  26. hover_name=hover_name,
  27. )
  28. else:
  29. x, y, z = plot_values[0], plot_values[1], plot_values[2]
  30. fig = px.scatter_3d(
  31. df,
  32. x=x,
  33. y=y,
  34. z=z,
  35. color=color,
  36. hover_data=hover_data,
  37. title=title,
  38. hover_name=hover_name,
  39. )
  40. if return_figure:
  41. return fig
  42. else:
  43. fig.show()