Extract all field names from nested dataclasses


I have a dataclass that contains within it another dataclass:

from dataclasses import dataclass

@dataclass
class A:
    var_1: str
    var_2: int

@dataclass
class B:
    var_3: float
    var_4: A

I would like to create a list of all field names for attributes that aren't dataclasses; if an attribute is itself a dataclass, I want to list the fields of that class instead. In this case the output would be ['var_3', 'var_1', 'var_2']. I know it's possible to use dataclasses.fields to get the fields of a simple dataclass, but I can't work out how to do it recursively for nested dataclasses. Ideally I would like to do it by passing just the class type B (in the same way you can pass the type to dataclasses.fields), rather than an instance of B. Is this possible?

Thank you!


Solution

  • Use dataclasses.fields() to iterate over all the fields, making a list of their names.

    Use dataclasses.is_dataclass() to tell if a field is a nested dataclass. If so, recurse into it instead of adding its name to the list.

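    from dataclasses import fields, is_dataclass

    def all_fields(c: type) -> list[str]:
        """Return the names of all non-dataclass fields of c, descending into nested dataclasses."""
        field_list = []
        for f in fields(c):
            if is_dataclass(f.type):
                # Nested dataclass: collect its field names instead of its own name.
                field_list.extend(all_fields(f.type))
            else:
                field_list.append(f.name)
        return field_list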
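
As a quick check, here is a self-contained sketch that runs the function against the classes from the question. The class C is hypothetical (it is not in the question) and is only there to show that the recursion goes more than one level deep.

```python
from dataclasses import dataclass, fields, is_dataclass


@dataclass
class A:
    var_1: str
    var_2: int


@dataclass
class B:
    var_3: float
    var_4: A


# Hypothetical extra class, not part of the question, to show deeper nesting.
@dataclass
class C:
    var_5: bool
    var_6: B


def all_fields(c: type) -> list[str]:
    field_list = []
    for f in fields(c):
        if is_dataclass(f.type):
            # Nested dataclass: recurse and collect its field names.
            field_list.extend(all_fields(f.type))
        else:
            field_list.append(f.name)
    return field_list


print(all_fields(B))  # ['var_3', 'var_1', 'var_2']
print(all_fields(C))  # ['var_5', 'var_3', 'var_1', 'var_2']
```

One caveat: if the module uses `from __future__ import annotations`, or a field is annotated with a string such as "A", then f.type is a string rather than a class. In that case is_dataclass(f.type) returns False and the nested field is listed by its own name instead of being expanded. Resolving the annotations first with typing.get_type_hints(c) is one possible way around that.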