
Jax ValueError: Incompatible shapes for broadcasting: shapes


I'm trying to write a weighted cross-entropy loss to train my model with JAX. However, I think there is an issue with my input dimensions. Here is my code:

import jax.numpy as np
from functools import partial
import jax

@partial(np.vectorize, signature="(c),(),()->()")
def weighted_cross_entropy_loss(logits, label, weights):
    one_hot_label = jax.nn.one_hot(label, num_classes=logits.shape[0])
    return -np.sum(weights* logits*one_hot_label)

logits=np.array([[1,2,3,4,5,6,7],[2,3,4,5,6,7,8]])
labels=np.array([1,2])
weights=np.array([1,2,3,4,5,6,7])
print(weighted_cross_entropy_loss(logits, labels, weights))

Here is the full traceback:

Traceback (most recent call last):
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 147, in broadcast_shapes
    return _broadcast_shapes_cached(*shapes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 153, in _broadcast_shapes_cached
    return _broadcast_shapes_uncached(*shapes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 169, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(2,), (2,), (7,)]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/PATH/test.py", line 15, in <module>
    print(weighted_cross_entropy_loss(a,label,weights))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/numpy/vectorize.py", line 274, in wrapped
    broadcast_shape, dim_sizes = _parse_input_dimensions(
                                 ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/numpy/vectorize.py", line 123, in _parse_input_dimensions
    broadcast_shape = lax.broadcast_shapes(*shapes)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 149, in broadcast_shapes
    return _broadcast_shapes_uncached(*shapes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 169, in _broadcast_shapes_uncached
    raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
ValueError: Incompatible shapes for broadcasting: shapes=[(2,), (2,), (7,)]

I'm expecting a single number that represents the cross-entropy loss between logits and labels.

I'm fairly new to this. Can somebody tell me what is going on? Any help is appreciated.


Solution

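  • label has shape (2,) and weights has shape (7,), which means they cannot be broadcast together. With the signature "(c),(),()->()", vectorize treats the last axis of logits as the core dimension and treats label and weights as scalars, so their entire shapes become batch shapes. Those batch shapes, (2,) from logits, (2,) from label and (7,) from weights, must all broadcast together, and (2,) and (7,) are incompatible. That is exactly the shapes=[(2,), (2,), (7,)] in your error.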

    It's not clear to me from your question what your expected outcome was, but you can read more about how broadcasting works in NumPy (and in JAX, which implements NumPy's semantics) at https://numpy.org/doc/stable/user/basics.broadcasting.html.

    Edit: it looks like this is the operation you were aiming for:

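    import jax
    import jax.numpy as np
    def weighted_cross_entropy_loss(logits, label, weights):
        # logits: (2, 7), label: (2,) -> one_hot_label: (2, 7); weights (7,) broadcasts over rows
        one_hot_label = jax.nn.one_hot(label, num_classes=logits.shape[1])
        return -np.sum(weights * logits * one_hot_label)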
    

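    Since you want a single scalar output, I don't think vectorize is the right mechanism to use here. It maps the function over the batch dimensions and returns one value per batch element (shape (2,) for your inputs) rather than reducing everything to one number.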
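To see where the shapes in the error come from, here is a minimal sketch that mimics (rather than calls) the shape check vectorize performs. It computes the batch shapes implied by the signature "(c),(),()->()" and passes them to jnp.broadcast_shapes, which raises the same kind of ValueError. It then checks that the arrays multiplied in the corrected, non-vectorized function do broadcast. The variable names here are illustrative only.

```python
import jax.numpy as jnp

# Inputs from the question.
logits = jnp.array([[1, 2, 3, 4, 5, 6, 7], [2, 3, 4, 5, 6, 7, 8]])  # (2, 7)
labels = jnp.array([1, 2])                                           # (2,)
weights = jnp.array([1, 2, 3, 4, 5, 6, 7])                           # (7,)

# Signature "(c),(),()->()": the trailing axis of logits is the core
# dimension (c); labels and weights have no core dimensions, so their
# whole shapes are treated as batch ("loop") shapes.
loop_shapes = [logits.shape[:-1], labels.shape, weights.shape]
print(loop_shapes)  # [(2,), (2,), (7,)]

# vectorize needs these batch shapes to broadcast together; (2,) vs (7,) fails.
try:
    jnp.broadcast_shapes(*loop_shapes)
    vectorize_shapes_ok = True
except ValueError as err:
    vectorize_shapes_ok = False
    print("ValueError:", err)

# Without vectorize, the corrected function multiplies weights (7,),
# logits (2, 7) and a one-hot array of shape (2, 7); these do broadcast.
one_hot_shape = labels.shape + (logits.shape[-1],)
fixed_shape = jnp.broadcast_shapes(weights.shape, logits.shape, one_hot_shape)
print(fixed_shape)  # (2, 7)
```

The broadcast result (2, 7) is then summed by np.sum in the corrected function, which is what turns it into a single number.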
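One more thing worth checking: the corrected function multiplies the raw logits by the weights, so it is a weighted sum of the true-class logits rather than a softmax cross-entropy. If your logits are unnormalized scores (an assumption on my part), a standard class-weighted cross-entropy first converts them to log-probabilities with jax.nn.log_softmax. Here is a minimal sketch; weighted_softmax_cross_entropy is a hypothetical name, and summing (rather than averaging) the per-example losses is just one choice.

```python
import jax
import jax.numpy as jnp


def weighted_softmax_cross_entropy(logits, labels, weights):
    """Hypothetical helper: class-weighted softmax cross-entropy, summed to a scalar.

    logits:  (batch, num_classes) raw, unnormalized scores
    labels:  (batch,) integer class indices
    weights: (num_classes,) one weight per class
    """
    logits = jnp.asarray(logits, dtype=jnp.float32)
    one_hot = jax.nn.one_hot(labels, num_classes=logits.shape[-1])  # (batch, C)
    log_probs = jax.nn.log_softmax(logits, axis=-1)                 # (batch, C)
    # weights (C,) broadcasts across the batch axis, so each example's
    # negative log-probability is scaled by the weight of its true class.
    return -jnp.sum(weights * one_hot * log_probs)


logits = jnp.array([[1, 2, 3, 4, 5, 6, 7], [2, 3, 4, 5, 6, 7, 8]])
labels = jnp.array([1, 2])
weights = jnp.array([1, 2, 3, 4, 5, 6, 7])

loss = weighted_softmax_cross_entropy(logits, labels, weights)
print(loss.shape)  # () -> a single scalar

# Because the output is a scalar, jax.grad can differentiate it directly
# with respect to the (float) logits.
grads = jax.grad(weighted_softmax_cross_entropy)(
    logits.astype(jnp.float32), labels, weights
)
print(grads.shape)  # (2, 7)
```

If you'd rather have a weighted mean than a sum, one common choice is to divide the result by the total weight of the true classes, jnp.sum(weights[labels]).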