Merge dataclasses in Python


I have a dataclass like:

import dataclasses
import jax.numpy as jnp

@dataclasses.dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

In my code, I create multiple instances of it. Is there an easy way to merge two of them attribute by attribute? For example, if I have:

a = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))
b = Metric(score1=jnp.array([10,10,10]), score2=jnp.array([20,20,20]), score3=jnp.array([30,30,30]))

I'd like to merge them into a single Metric containing:

score1=jnp.array([10,10,10,10,10,10]), score2=jnp.array([20,20,20,20,20,20]) and score3=jnp.array([30,30,30,30,30,30])


Solution

  • The easiest thing is probably just to define a method:

    import dataclasses
    import jax.numpy as jnp
    
    
    @dataclasses.dataclass
    class Metric:
        score1: jnp.ndarray
        score2: jnp.ndarray
        score3: jnp.ndarray
    
        def concatenate(self, other):
            return Metric(
                jnp.concatenate((self.score1, other.score1)),
                jnp.concatenate((self.score2, other.score2)),
                jnp.concatenate((self.score3, other.score3)),
            )
    
    

    and then call a.concatenate(b). Alternatively, you could name the method __add__, which would let you write a + b. This is neater, but it could be confused with element-wise addition.
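If the class gains more fields later, listing each one by hand gets tedious. As a possible variation (a sketch, not part of the original answer), the method can loop over `dataclasses.fields` instead. It assumes every field holds an array that `jnp.concatenate` accepts:

```python
import dataclasses
import jax.numpy as jnp


@dataclasses.dataclass
class Metric:
    score1: jnp.ndarray
    score2: jnp.ndarray
    score3: jnp.ndarray

    def concatenate(self, other):
        # Assumption: every field is an array that can be concatenated.
        # Looping over the declared fields means new fields are picked up
        # without editing this method.
        merged = {
            f.name: jnp.concatenate((getattr(self, f.name), getattr(other, f.name)))
            for f in dataclasses.fields(self)
        }
        return type(self)(**merged)


a = Metric(score1=jnp.array([10, 10, 10]), score2=jnp.array([20, 20, 20]), score3=jnp.array([30, 30, 30]))
b = Metric(score1=jnp.array([10, 10, 10]), score2=jnp.array([20, 20, 20]), score3=jnp.array([30, 30, 30]))

merged = a.concatenate(b)
print(merged.score1.shape)  # (6,)
```

Using `type(self)(**merged)` rather than `Metric(...)` should also keep the method working for subclasses, as long as they only add array-valued fields.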