Is there a way to make a collection class of a single element class?


Assuming I have a Python class A, is there a way to make a class A_Collection such that any attribute of A, or any function applied to A, can also be applied to A_Collection and is passed down to its individual elements? Say, for example:

    >>> a1 = A(*args1)
    >>> a2 = A(*args2)
    >>> ac = A_Collection([a1, a2])
    >>> ac.some_attribute_of_A
    [a1.some_attribute_of_A, a2.some_attribute_of_A]    # desired result
    >>> import numpy as np
    >>> np.array(ac)
    [np.array(a1), np.array(a2)]    # desired result

I think I have figured out how to forward attributes, but I don't know how to deal with functions that A itself knows nothing about, such as np.array.

    from inspect import ismethod
    import numpy as np

    class A:
        def __init__(self, i):
            self.i = i

        def add_one(self):
            return self.i + 1

    class A_Collection(list):
        def __getattr__(self, name):
            # Only called for names that list itself does not define.
            attrs = [getattr(a, name) for a in self]
            if attrs and ismethod(attrs[0]):
                # Return a function that calls the method on every element.
                def collection(*args, **kwargs):
                    return [attr(*args, **kwargs) for attr in attrs]
                return collection
            return attrs

    ac = A_Collection([A(5), A(6), A(7)])
    print(ac.i)            # gives [5, 6, 7]
    print(ac.add_one())    # gives [6, 7, 8]
    print(np.array(ac))
    # gives [<__main__.A object at 0x7f516be93c10>
    #        <__main__.A object at 0x7f516be0e090>
    #        <__main__.A object at 0x7f516be0dc50>]
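Part of the difficulty is that `__getattr__` is only a fallback: Python calls it only after normal lookup on the instance and its class has failed. A function such as `len` or `np.array` therefore treats `ac` as an ordinary list, with no idea that its elements should be handled one by one, and any name that `list` already defines (`count`, `index`, `sort`, ...) is never forwarded to the elements. The minimal sketch below shows both effects; `ForwardingList` and the element class `B`, whose `count` attribute collides with `list.count`, are hypothetical names made up for illustration.

```python
class ForwardingList(list):
    """Hypothetical minimal version of A_Collection: forwards plain attributes."""

    def __getattr__(self, name):
        # Only reached when normal lookup on the list (and its class) fails.
        return [getattr(item, name) for item in self]


class B:
    """Hypothetical element class; `count` clashes with list.count."""

    def __init__(self, count, size):
        self.count = count
        self.size = size


bc = ForwardingList([B(1, 10), B(2, 20)])

# `size` is not a list attribute, so __getattr__ forwards it to every element.
print(bc.size)             # [10, 20]

# `count` is a list method, so __getattr__ is never called for it.
print(bc.count(bc[0]))     # 1: list.count found that element once

# Built-ins see the list as a whole, not its elements.
print(len(bc))             # 2
```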

Solution

  • It is possible to do that for attribute access on the collection, but not when the collection is passed as an argument to a function: there is no way to intercept that, so the called function would have to know how to deal with a collection. Instead, you can give the collection an "apply" method that applies a given function to each of its members and returns a new collection of the results. That is trivial to do:

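    from collections.abc import Sequence

    class ApplierSequence(Sequence):
        def __init__(self, initial):
            self.data = list(initial)

        def apply(self, function):
            # Call `function` on every member and wrap the results in a new collection.
            return type(self)([function(item) for item in self.data])

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return len(self.data)

        def __getattr__(self, attr):
            # Only reached when normal lookup fails. Refuse "data" and dunder names
            # so copy/pickle probes raise AttributeError instead of recursing.
            if attr == "data" or (attr.startswith("__") and attr.endswith("__")):
                raise AttributeError(attr)
            data = [getattr(item, attr) for item in self.data]
            if data and callable(data[0]):
                # Methods are called at once with no arguments (seq.add_one, not seq.add_one()).
                data = [item() for item in data]
            return type(self)(data)

        def __repr__(self):
            return f"{self.__class__.__name__}({self.data})"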
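As written, the answer's `__getattr__` calls every method immediately with no arguments, so `seq.add_one` (without parentheses) already returns the results, and a method that needs parameters cannot be forwarded. Below is a sketch of one possible variant, assuming you prefer the question's behaviour of returning a callable that passes arguments through; `CallableApplierSequence` and the method `A.add` are hypothetical names used only for illustration. Outside functions such as `str` (or `np.array`) still go through `apply`, since a plain function call on the collection cannot be intercepted.

```python
from collections.abc import Sequence


class A:
    def __init__(self, i):
        self.i = i

    def add(self, n):
        # Hypothetical method taking an argument, for illustration only.
        return self.i + n


class CallableApplierSequence(Sequence):
    """Hypothetical variant: methods are forwarded as callables that take arguments."""

    def __init__(self, initial):
        self.data = list(initial)

    def apply(self, function, *args, **kwargs):
        # Run an outside function on every element; extra arguments are passed through.
        return type(self)(function(item, *args, **kwargs) for item in self.data)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

    def __getattr__(self, attr):
        # Avoid infinite recursion when `data` is not set yet, and ignore dunder probes.
        if attr == "data" or (attr.startswith("__") and attr.endswith("__")):
            raise AttributeError(attr)
        values = [getattr(item, attr) for item in self.data]
        if values and all(callable(value) for value in values):
            # Return a function that forwards its arguments to every element's method.
            def call_all(*args, **kwargs):
                return type(self)(value(*args, **kwargs) for value in values)
            return call_all
        return type(self)(values)

    def __repr__(self):
        return f"{type(self).__name__}({self.data!r})"


seq = CallableApplierSequence([A(5), A(6), A(7)])
print(seq.i)              # CallableApplierSequence([5, 6, 7])
print(seq.add(10))        # CallableApplierSequence([15, 16, 17])
print(seq.i.apply(str))   # CallableApplierSequence(['5', '6', '7'])
```

Whether methods are called at once (as in the answer) or returned as callables (as here) is a design choice: the second form supports arguments, at the cost of always needing parentheses, e.g. `seq.add(10)`.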