Search code examples
python, pandas, multi-index

Pandas: custom class as column header with multi indexing


I'm trying to use instances of a custom class as column headers in a multi-indexed DataFrame, but I can't get it to work. Implementing __eq__, __hash__ and __str__ is only enough for simple (single-level) data frames.

Here is a small example:

class Signal:
    """A named signal, intended for use as a DataFrame column label.

    Equality, hashing and string conversion are all derived from ``name``.
    """

    def __init__(self, name):
        self.name = name

    def __hash__(self):
        # Hash exactly like the bare name does.
        return hash(self.name)

    def __str__(self):
        return str(self.name)

    def __eq__(self, other):
        # Matches either the bare name itself or any object exposing an
        # equal ``name`` attribute; objects without ``name`` are unequal.
        try:
            matches_name = self.name == other
            return matches_name if matches_name else self.name == other.name
        except AttributeError:
            return False

if __name__ == '__main__':
    import pandas as pd
    import numpy as np

    # Three distinct signals that serve as second-level column labels.
    sig_name, sig_name2, sig_something = (
        Signal(label) for label in ('name', 'name2', 'something')
    )

    # Each key is a (top-level label, Signal) tuple.
    columns = {
        ('A', sig_name): np.arange(2),
        ('A', sig_name2): np.ones(2),
        ('B', sig_something): np.zeros(2),
    }

    frame = pd.DataFrame(columns)

    # Show the whole frame first, then only the columns under 'A'.
    print(frame)
    print('-----------')
    print(frame['A'])

I also tried implementing __le__, __ge__ and __ne__. That did not do anything though. I don't really have a clue what else I could do. Anybody got some ideas?


Solution

  • After defining __lt__ and __gt__, so that Signal objects can be ordered, the example works as expected:

    class Signal:
        """Named signal that can be used as a column label in a DataFrame.

        Equality, ordering and hashing are all based on ``name``. A Signal
        compares equal to its plain name as well as to another Signal with
        the same name, and it can be ordered against either.
        """

        def __init__(self, name):
            self.name = name

        def __eq__(self, other):
            # Equal to the bare name or to any object carrying an equal
            # ``name``; anything without a ``name`` attribute is unequal.
            try:
                return self.name == other or self.name == other.name
            except AttributeError:
                return False

        def __lt__(self, other):
            # Fall back to ``other`` itself when it has no ``name`` so that
            # ordering accepts plain names, mirroring __eq__.
            return self.name < getattr(other, 'name', other)

        def __gt__(self, other):
            # Same fallback as __lt__.
            return self.name > getattr(other, 'name', other)

        def __str__(self):
            return str(self.name)

        def __hash__(self):
            # Must agree with __eq__: hash like the bare name.
            return hash(self.name)
    

    import pandas as pd
    import numpy as np
    # Three distinct signals that serve as second-level column labels.
    sig_name = Signal('name')
    sig_name2 = Signal('name2')
    sig_something = Signal('something')

    # Each key is a (top-level label, Signal) tuple.
    columns = {
        ('A', sig_name): np.arange(2),
        ('A', sig_name2): np.ones(2),
        ('B', sig_something): np.zeros(2),
    }

    frame = pd.DataFrame(columns)
    under_a = frame['A']

    # Print the full frame and the 'A' sub-frame, separated by a blank line.
    print(frame, under_a, sep='\n\n')
    
         A               B
      name name2 something
    0    0   1.0       0.0
    1    1   1.0       0.0
    
       name  name2
    0     0    1.0
    1     1    1.0