python, dictionary, karma-runner, defaultdict

python: defaultdict with non-default argument


I want something like a dict of TestClass instances, where TestClass takes a required constructor argument. When I look up a key, I don't know whether that element has been requested before. Here is TestClass:

class TestClass(object):
    def __init__(self, name):
        self.name = name
        self.state = 0
    def getName(self):
        self.state = self.state + 1
        return "%s -- %i" % (self.name, self.state)

Then the dict and the accessing function:

db = {}
def getOutput(key):
    # this is a marvel in the world of programming languages
    if key not in db:
        db[key] = TestClass(key)
    return db[key]

And the actual testing code:

if __name__ == "__main__":
    print("testing: %s" % getOutput('charlie').getName())

Nice. But I wonder if there is a more elegant solution. While browsing, defaultdict came to mind. But this won't work, because I cannot pass an argument to the default_factory:

from collections import defaultdict
d = defaultdict(TestClass)
print("testing %s" % d['tom'].getName())

gives TypeError: __init__() takes exactly 2 arguments (1 given)... Is there another solution?

Besides, I wanna improve my Python. So any other suggestions are welcome as well ;-)
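
The TypeError quoted above can be reproduced in a few lines. This is a minimal sketch, assuming Python 3 (where the error message is worded differently from the Python 2 text in the question); TestClass is trimmed to its constructor for brevity:

```python
from collections import defaultdict


class TestClass(object):
    def __init__(self, name):
        self.name = name


d = defaultdict(TestClass)

try:
    d['tom']  # missing key: defaultdict calls TestClass() with no arguments
except TypeError as exc:
    error = exc
    print("TypeError: %s" % exc)

# The failed lookup stored nothing under 'tom'.
print('tom' in d)  # False
```

The key never reaches the factory, which is why a plain defaultdict cannot build a TestClass(key).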


Solution

  • The defaultdict factory indeed does not take an argument.

    You can create your own variant that does, however; the trick is to define a __missing__ method:

    class TestClassDict(dict):
        def __missing__(self, key):
            res = self[key] = TestClass(key)
            return res
    

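    Whenever dict[key] is looked up for a key that does not exist, dict calls the __missing__ method (if the subclass defines one) and returns its result. defaultdict uses this hook to call its factory with no arguments and store the result under the key, but you can provide your own and pass in key.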

    Demo:

    >>> class TestClass(object):
    ...     def __init__(self, name):
    ...         self.name = name
    ...         self.state = 0
    ...     def getName(self):
    ...         self.state = self.state + 1
    ...         return "%s -- %i" % (self.name, self.state)
    ... 
    >>> class TestClassDict(dict):
    ...     def __missing__(self, key):
    ...         res = self[key] = TestClass(key)
    ...         return res
    ... 
    >>> db = TestClassDict()
    >>> db['charlie'].getName()
    'charlie -- 1'
    >>> db
    {'charlie': <__main__.TestClass object at 0x102f72250>}
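
The same hook can be generalised so the dict is not tied to TestClass. This is a minimal sketch, assuming Python 3; KeyedDefaultDict is a hypothetical name for illustration, not something in the standard library. It subclasses defaultdict and overrides __missing__ so that default_factory receives the key:

```python
from collections import defaultdict


class TestClass(object):
    def __init__(self, name):
        self.name = name
        self.state = 0

    def getName(self):
        self.state += 1
        return "%s -- %i" % (self.name, self.state)


class KeyedDefaultDict(defaultdict):
    """Hypothetical helper: like defaultdict, but the factory gets the key."""

    def __missing__(self, key):
        if self.default_factory is None:
            raise KeyError(key)
        # Build the value from the key and store it for later lookups.
        value = self[key] = self.default_factory(key)
        return value


db = KeyedDefaultDict(TestClass)
first = db['charlie'].getName()
second = db['charlie'].getName()  # same instance, so state keeps counting
print(first)   # charlie -- 1
print(second)  # charlie -- 2

# Only db[key] triggers __missing__; get() and `in` do not create entries.
print(db.get('tom'))  # None
print('tom' in db)    # False
```

Any callable that accepts the key can serve as the factory here, for example a function or a lambda, so the one subclass can be reused for other value types.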