
Wandb Plotting the first element of a list


I've used the following code to log a metric

wandb.log({"metric": [4, 5, 6]})

but then found out that Wandb doesn't support plotting a list by default. I want to create a line plot where the y-axis is the first element of the metric array, and the x-axis is the step.

I've read the Custom Charts section of the documentation, and I think I should use Vega to access the first element of the array. Here is what I've done so far: for the custom chart, I set the data source to "history" and selected the "metric" key.

query {
    runSets
         (runSets: "${runSets}" ) {
            id
            name
            history
                (keys: ["metric" ] )
        }
}

In the Vega script, I tried to flatten the array, following this part of the documentation:

"transform": {
    ...
    {"type": "flatten", "fields": ["${field:metric}"]},
}

This gives me a warning that the "type" and "fields" arguments are not allowed, which suggests that this flattening block belongs somewhere other than the transform section. I don't know where it should go, or how to achieve this. Is this even possible? If not, I think I should write a script in my notebook that accesses the wandb.run log data and transforms the data for each run. If so, any tips for that solution are also appreciated.


Solution

  • EDIT: A better solution for this problem would be the Wandb "Weave Table". Check here for the complete explanation.

    As others have pointed out, the correct way to log a list for it to be plotted correctly with Wandb is to log each item separately.

    But my problem was that I had already trained a model while logging the metrics as lists, and I didn't want to train the model all over again just to get the plots displayed correctly.

    The Wandb API supports retrieving the log history of a run and adding new logs to it. Using this approach, we can retrieve the history of the metric that was incorrectly logged as a list and re-log it in the correct format (i.e. each item of the list separately).

    As an example, the following code block can be used to plot an already-logged list metric. It assumes that a run in a project called project_name (under the entity workspace) has logged a key named metric as a list. It retrieves the log history and logs the result as a new (properly working) plot in Wandb.

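    import matplotlib.pyplot as plt
    import wandb

    api = wandb.Api()  # client for reading data that was already logged
    runs = api.runs('workspace/project_name')

    # Only the first run is handled here; drop the [:1] to process every run.
    for i in range(len(runs))[:1]:
        run = runs[i]
        # Each row holds the list logged under "metric" at one step.
        values = [row['metric'] for row in run.scan_history(keys=["metric"])]

        plt.figure()  # start a fresh figure for each run
        plt.plot(values)
        plt.ylabel("metric")
        # Resume the same run so the plot is attached to it.
        wandb.init(
            entity="workspace",
            project="project_name",
            id=run.id,
            resume=True
        )
        wandb.log({"metric_plot": plt})
        wandb.finish()  # close the resumed run before the next one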
    

    To avoid this issue in the future, I use the following function. It makes sure that all lists in a dictionary are converted to distinct items.

    from copy import deepcopy
    
    def to_wandb_format(d: dict) -> dict:
        """
        Unpack list values in the dictionary, as wandb can't plot list values.
        Example:
            Input: {"metric": [99, 88, 77]}
            Output: {"metric_0": 99, "metric_1": 88, "metric_2": 77}
        """
        new_d = deepcopy(d)
        for key, val in d.items():
            if isinstance(val, list):
                new_d.pop(key)
                new_d.update({
                    f'{key}_{i}': v for i, v in enumerate(val)
                })
    
        return new_d
    

    Which can be used as follows:

    wandb.log(to_wandb_format(epoch_train_metrics))
    

    assuming epoch_train_metrics is a dictionary that may have lists as values.
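
For the original goal (the first element of the list on the y-axis, the step on the x-axis), the retrieved history rows can also be reduced to plain scalars before plotting or re-logging. The following is only a minimal sketch under assumptions: each history row is a dict holding the list under "metric" and, optionally, the step under "_step". The helper names unpack_history and first_element_series are hypothetical and not part of Wandb, and sample_rows is made-up data standing in for the rows returned by scan_history.

```python
from typing import Any, Dict, Iterable, Iterator, List, Tuple


def unpack_history(
    rows: Iterable[Dict[str, Any]], key: str = "metric"
) -> Iterator[Tuple[int, Dict[str, Any]]]:
    """Yield (step, scalars) pairs, with list values split into key_0, key_1, ..."""
    for i, row in enumerate(rows):
        value = row.get(key)
        if not isinstance(value, list):
            continue  # skip rows where the key is missing or not a list
        # Assumption: rows carry "_step"; fall back to the row index otherwise.
        step = row.get("_step", i)
        yield step, {f"{key}_{j}": v for j, v in enumerate(value)}


def first_element_series(
    rows: Iterable[Dict[str, Any]], key: str = "metric"
) -> Tuple[List[int], List[Any]]:
    """Return (steps, first elements), i.e. the x and y values of the line plot."""
    steps, firsts = [], []
    for step, scalars in unpack_history(rows, key):
        if f"{key}_0" in scalars:  # empty lists have no first element
            steps.append(step)
            firsts.append(scalars[f"{key}_0"])
    return steps, firsts


# Made-up rows standing in for the output of run.scan_history(...)
sample_rows = [
    {"_step": 0, "metric": [4, 5, 6]},
    {"_step": 1, "metric": [7, 8, 9]},
]
steps, firsts = first_element_series(sample_rows)
print(steps, firsts)  # [0, 1] [4, 7]
```

With a real run, the rows would presumably come from something like run.scan_history(keys=["_step", "metric"]); that "_step" can be requested this way is an assumption worth checking. The steps and firsts lists can go straight into plt.plot(steps, firsts), or each scalars dict from unpack_history can be passed to wandb.log inside the resumed run. As far as I know, points logged into a resumed run are appended after the existing history, so it may be worth logging the original step as an extra key and using it as the x-axis.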