I have a dataclass like:
import dataclasses
import jax.numpy as jnp

@dataclasses.dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray
In my code I create multiple instances of it. Is there an easy way to merge two of them attribute by attribute? For example, if I have:
a = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
b = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
I'd like to merge them so that I get a single Metric containing:
score1=jnp.array([10,10,10,10,10,10]), score2=jnp.array([20,20,20,20,20,20]) and score3=jnp.array([30,30,30,30,30,30])
The easiest thing is probably just to define a method:
import dataclasses
import jax.numpy as jnp

@dataclasses.dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

    def concatenate(self, other):
        # Join each attribute of self with the matching attribute of other
        # and return a new Metric; neither input is modified.
        return Metric(
            jnp.concatenate((self.score1, other.score1)),
            jnp.concatenate((self.score2, other.score2)),
            jnp.concatenate((self.score3, other.score3)),
        )
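and then just do a.concatenate(b). You could instead name the method __add__, which would let you write a + b. This is neater, but it could be confused with element-wise addition: + on two jnp arrays adds them element by element, so a reader might expect (a + b).score1 to be [20, 20, 20] rather than a six-element array.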
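If the class may grow more fields, the method can loop over dataclasses.fields instead of naming each score, so adding a score4 later needs no change to the merge logic. The sketch below is my own variation, not part of the answer above: it also aliases __add__ to the same method (with the caveat about element-wise addition still applying), and adds a hypothetical helper, concat_all, for merging a whole list of instances with one jnp.concatenate call per field instead of chaining pairwise merges.

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

    def concatenate(self, other: "Metric") -> "Metric":
        # Iterate over every declared dataclass field rather than naming
        # score1/score2/score3 explicitly, and concatenate each pair of arrays.
        return type(self)(**{
            f.name: jnp.concatenate((getattr(self, f.name), getattr(other, f.name)))
            for f in dataclasses.fields(self)
        })

    # Optional alias so that `a + b` concatenates.
    # Note: this is NOT element-wise addition of the scores.
    __add__ = concatenate


def concat_all(metrics):
    # Hypothetical helper: merge any number of Metric instances with a single
    # jnp.concatenate call per field, avoiding repeated intermediate copies.
    return Metric(**{
        f.name: jnp.concatenate([getattr(m, f.name) for m in metrics])
        for f in dataclasses.fields(Metric)
    })


a = Metric(score1=jnp.array([10, 10, 10]), score2=jnp.array([20, 20, 20]), score3=jnp.array([30, 30, 30]))
b = Metric(score1=jnp.array([10, 10, 10]), score2=jnp.array([20, 20, 20]), score3=jnp.array([30, 30, 30]))

merged = a + b                  # same as a.concatenate(b)
everything = concat_all([a, b, a])
```

With many instances, concat_all is usually preferable to reducing with +, since each pairwise concatenation allocates a new, progressively larger array.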