
Set dataclass variable to double the default value


This is probably too much of a perfectionist question, but... I'd like to instantiate a dataclass with double the default values, so that if I later change a default, the doubled instance automatically follows it. I'm hoping there is a better way to do this than reading the defaults and updating the values by hand.

Here is a simple example to explain what I want:

from dataclasses import dataclass

@dataclass
class Value:
    variable1: int = 1
    variable2: int = 2
    variable3: int = 3    

value = Value()
doubled_value = Value() {* 2 for all}  # pseudocode (not valid Python): every default doubled

print(value.variable1)
print(value.variable2)
print(value.variable3)

print(doubled_value.variable1)
print(doubled_value.variable2)
print(doubled_value.variable3)

Output:
1
2
3

2
4
6

This works but is hard to understand and ugly:

from dataclasses import dataclass

@dataclass
class Value:
    variable1: int = 1
    variable2: int = 2
    variable3: int = 3    

value = Value()
doubled_value = Value()

# Overwrite each field of the new instance with twice its class default
for name, f in doubled_value.__dataclass_fields__.items():
    setattr(doubled_value, name, f.default * 2)

print(value.variable1)
print(value.variable2)
print(value.variable3)

print(doubled_value.variable1)
print(doubled_value.variable2)
print(doubled_value.variable3)

Output:
1
2
3

2
4
6

To be clear, I want it to work regardless of the default values, so hard-coding doubled_value = Value(2, 4, 6) won't work.
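For comparison, here is a minimal sketch of the same "read the defaults, then update" idea using only the public `dataclasses.fields()` and `dataclasses.replace()` functions instead of the `__dataclass_fields__` attribute. `scaled_defaults` is a hypothetical helper name, and it assumes every field has a plain default that supports `*`.

```python
from dataclasses import dataclass, fields, replace


@dataclass
class Value:
    variable1: int = 1
    variable2: int = 2
    variable3: int = 3


def scaled_defaults(cls, factor=2):
    """Hypothetical helper: return an instance of ``cls`` whose fields are
    the class defaults multiplied by ``factor``.

    Assumes every field has a plain ``default`` that supports ``*``.
    """
    base = cls()  # an instance built with the normal defaults
    # replace() builds a new instance with the given fields changed;
    # `base` itself is left untouched
    return replace(base, **{f.name: f.default * factor for f in fields(cls)})


value = Value()
doubled_value = scaled_defaults(Value)
print(value)          # Value(variable1=1, variable2=2, variable3=3)
print(doubled_value)  # Value(variable1=2, variable2=4, variable3=6)
```

Because the new values are computed from `fields(cls)` at call time, editing a default in the class body changes the doubled instance too.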

Thank you for any help. This is my first post on Stack Overflow, so please let me know if I've made any mistakes.


Solution

  • Are you looking for something like this?

    from dataclasses import dataclass, fields
    
    @dataclass
    class Value:
        variable1: int = 1
        variable2: int = 2
        variable3: int = 3    
    
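        @classmethod
        def times(cls, times=2):
            # Multiply each field's default and pass the results positionally,
            # in the same order the fields are declared
            return cls(*(f.default * times for f in fields(cls)))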
    
    value = Value()
    doubled_value = Value.times()
    tripled_value = Value.times(3)
    

    fields() returns a tuple of the class's Field objects, each of which carries its default value. I simply multiply each default in a generator expression and unpack the results as positional arguments to the class.
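    A minimal sketch extending the `times` classmethod above, assuming you may later add fields declared with `default_factory`, fields with `init=False`, or fields with no default at all. Building keyword arguments, skipping `init=False` fields, and raising a clear error for fields without a default are my own additions, not part of the answer above.

```python
from dataclasses import MISSING, dataclass, fields


@dataclass
class Value:
    variable1: int = 1
    variable2: int = 2
    variable3: int = 3

    @classmethod
    def times(cls, times=2):
        kwargs = {}
        for f in fields(cls):
            if not f.init:
                # init=False fields cannot be passed to __init__,
                # so they keep their normal (unscaled) default
                continue
            if f.default is not MISSING:
                base = f.default
            elif f.default_factory is not MISSING:
                base = f.default_factory()  # build a fresh default value
            else:
                raise TypeError(f"field {f.name!r} has no default to scale")
            kwargs[f.name] = base * times
        # Keyword arguments do not depend on the declaration order of fields
        return cls(**kwargs)


value = Value()
doubled_value = Value.times()
tripled_value = Value.times(3)
print(doubled_value)  # Value(variable1=2, variable2=4, variable3=6)
print(tripled_value)  # Value(variable1=3, variable2=6, variable3=9)
```

    Note that `*` means whatever the field's type defines: for a list default it repeats the list rather than doubling its elements.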