Search code examples
pythonwandb

How does one save a plot in wandb with wandb.log?


I'm trying to save a plot with wandb.log. Their docs say to do:

    wandb.log({"chart": plt})

but this fails for me.

I get two different errors. The 1st error occurs when I do NOT call plt.show() before calling wandb.log:

Traceback (most recent call last):
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydevd_bundle/pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "<input>", line 1, in <module>
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/wandb_run.py", line 256, in wrapper
    return func(self, *args, **kwargs)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/wandb_run.py", line 222, in wrapper
    return func(self, *args, **kwargs)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/wandb_run.py", line 1548, in log
    self._log(data=data, step=step, commit=commit)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/wandb_run.py", line 1339, in _log
    self._partial_history_callback(data, step, commit)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/wandb_run.py", line 1228, in _partial_history_callback
    self._backend.interface.publish_partial_history(
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/interface/interface.py", line 541, in publish_partial_history
    data = history_dict_to_json(run, data, step=user_step, ignore_copy_err=True)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/data_types/utils.py", line 54, in history_dict_to_json
    payload[key] = val_to_json(
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/data_types/utils.py", line 82, in val_to_json
    val = Plotly.make_plot_media(val)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/data_types/plotly.py", line 48, in make_plot_media
    val = util.matplotlib_to_plotly(val)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/util.py", line 560, in matplotlib_to_plotly
    return tools.mpl_to_plotly(obj)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/plotly/tools.py", line 112, in mpl_to_plotly
    matplotlylib.Exporter(renderer).run(fig)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/plotly/matplotlylib/mplexporter/exporter.py", line 53, in run
    self.crawl_fig(fig)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/plotly/matplotlylib/mplexporter/exporter.py", line 124, in crawl_fig
    self.crawl_ax(ax)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/plotly/matplotlylib/mplexporter/exporter.py", line 146, in crawl_ax
    self.draw_collection(ax, collection)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/plotly/matplotlylib/mplexporter/exporter.py", line 289, in draw_collection
    offset_order = offset_dict[collection.get_offset_position()]
AttributeError: 'LineCollection' object has no attribute 'get_offset_position'

The 2nd error occurs when I DO call plt.show() before calling wandb.log:

Traceback (most recent call last):
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydevd_bundle/pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "<input>", line 1, in <module>
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/wandb_run.py", line 256, in wrapper
    return func(self, *args, **kwargs)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/wandb_run.py", line 222, in wrapper
    return func(self, *args, **kwargs)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/wandb_run.py", line 1548, in log
    self._log(data=data, step=step, commit=commit)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/wandb_run.py", line 1339, in _log
    self._partial_history_callback(data, step, commit)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/wandb_run.py", line 1228, in _partial_history_callback
    self._backend.interface.publish_partial_history(
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/interface/interface.py", line 541, in publish_partial_history
    data = history_dict_to_json(run, data, step=user_step, ignore_copy_err=True)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/data_types/utils.py", line 54, in history_dict_to_json
    payload[key] = val_to_json(
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/data_types/utils.py", line 82, in val_to_json
    val = Plotly.make_plot_media(val)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/sdk/data_types/plotly.py", line 48, in make_plot_media
    val = util.matplotlib_to_plotly(val)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/wandb/util.py", line 560, in matplotlib_to_plotly
    return tools.mpl_to_plotly(obj)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/plotly/tools.py", line 112, in mpl_to_plotly
    matplotlylib.Exporter(renderer).run(fig)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/plotly/matplotlylib/mplexporter/exporter.py", line 53, in run
    self.crawl_fig(fig)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/plotly/matplotlylib/mplexporter/exporter.py", line 122, in crawl_fig
    with self.renderer.draw_figure(fig=fig, props=utils.get_figure_properties(fig)):
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/plotly/matplotlylib/mplexporter/renderers/base.py", line 45, in draw_figure
    self.open_figure(fig=fig, props=props)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/plotly/matplotlylib/renderer.py", line 90, in open_figure
    self.mpl_x_bounds, self.mpl_y_bounds = mpltools.get_axes_bounds(fig)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/plotly/matplotlylib/mpltools.py", line 265, in get_axes_bounds
    x_min, y_min, x_max, y_max = min(x_min), min(y_min), max(x_max), max(y_max)
ValueError: min() arg is an empty sequence

Note that their trivial example DOES work:

# Minimal example: log the current pyplot state under the key "chart".
# NOTE(review): `wandb` is used below without a visible import; presumably it
# was already imported and a run initialized in the session — confirm.
import matplotlib.pyplot as plt

plt.plot([1, 2, 3, 4])
plt.ylabel("some interesting numbers")
wandb.log({"chart": plt})

for me.


cross posted: https://community.wandb.ai/t/how-does-one-save-a-plot-in-wandb-with-wandb-log/2373


Solution

  • Use:

        wandb.log({"plot": wandb.Image(fig)}) 
    

    – Rylan Schaeffer


    Full working example:

    
    import numpy as np
    import matplotlib.pyplot as plt
    import wandb
    
    def generate_data() -> tuple[np.ndarray, np.ndarray]:
        """Return a fixed constants array and a noisy target-loss array.

        Returns:
            A pair ``(C, L_target)``. ``C`` has shape ``(4, 2)``; its second
            column is the constant ``2e12``. ``L_target`` has shape ``(4, 1)``
            and holds four base loss values, each perturbed by Gaussian noise
            with standard deviation 0.01 drawn from NumPy's global RNG.
        """
        K = 1  # number of columns in L_target

        # Each row pairs a varying first entry with the same second entry.
        C = np.array([[first, 2e12] for first in (7e9, 13e9, 34e9, 70e9)])
        m = len(C)  # number of rows (4)

        # Base loss values, tiled K times along the column axis -> shape (m, K).
        base_losses = np.array([[1.85], [1.66], [1.55], [1.5]])
        L_target = np.repeat(base_losses, K, axis=1)
        assert m == L_target.shape[0], f"{m=} {L_target.shape[0]=}"

        # Perturb every target with independent zero-mean Gaussian noise.
        L_target = L_target + np.random.normal(loc=0.0, scale=0.01, size=(m, K))

        return C, L_target
    
    def main():
        """Plot the generated data and log the figure to Weights & Biases.

        Starts a W&B run, draws the first column of ``C`` against the
        flattened target losses, logs the figure as a ``wandb.Image`` under
        the key ``"plot"``, and finally shows the plot.
        """
        # Start the W&B run that will receive the logged image.
        wandb.init(project='your_project_name')

        C, L_target = generate_data()

        # 1-D series for plotting: first column of C vs. flattened losses.
        x_values = C[:, 0]
        y_values = L_target.flatten()

        line_kwargs = {
            "linestyle": '-',
            "marker": 'x',
            "label": 'Random Scaling Law',
        }

        fig, ax = plt.subplots()
        ax.plot(x_values, y_values, **line_kwargs)

        # NOTE(review): the title says "Log-Log" but the y-axis below is set
        # to linear — confirm which one is intended.
        ax.set_title("Scaling Law Prediction (Log-Log Plot)")
        ax.set_xlabel("Log Compute (log C), C=6ND")
        ax.set_ylabel("Loss")
        ax.set_xscale("log")
        ax.set_yscale("linear")
        ax.grid(True)
        ax.legend()

        # Wrap the figure in wandb.Image so it is logged as an image rather
        # than passing the matplotlib object to wandb.log directly.
        figure_image = wandb.Image(fig)
        wandb.log({"plot": figure_image})

        # Show the plot after it has been logged.
        plt.show()
    
    # Run the example only when this file is executed as a script.
    if __name__ == "__main__":
        main()