
Suppress key addition in collections.defaultdict


When a missing key is queried in a defaultdict object, the key is automatically added to the dictionary:

from collections import defaultdict

d = defaultdict(int)
res = d[5]

print(d)
# defaultdict(<class 'int'>, {5: 0})
# we want this dictionary to remain empty
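Only subscript access (`d[key]`) goes through defaultdict's `__missing__` hook, which is what stores the default. `get()` and the `in` operator bypass that hook, so read-only lookups can use them without inserting anything. Note that `get()` ignores `default_factory`, so the default has to be passed explicitly. A minimal check:

```python
from collections import defaultdict

d = defaultdict(int)

# .get() bypasses __missing__, so nothing is inserted (default passed explicitly)
value = d.get(5, 0)
# membership tests never insert either
present = 5 in d
print(value, present, dict(d))  # 0 False {}

# subscript access calls __missing__, which stores the default
_ = d[5]
print(dict(d))  # {5: 0}
```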

However, we usually want keys to be added only when they are assigned, either explicitly or implicitly:

d[8] = 1  # we want this key added
d[3] += 1 # we want this key added
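The implicit case works because `d[3] += 1` runs as a lookup followed by a store, roughly `d[3] = d[3] + 1`. A mapping whose missing-key lookup returns a default without storing it therefore still gains the key on `+=`. The sketch below uses a hypothetical `TracingDict` helper (not from the original) that records each item access:

```python
class TracingDict(dict):
    """Hypothetical helper: logs item access; missing keys read as 0 without being stored."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.calls = []

    def __getitem__(self, key):
        self.calls.append(("get", key))
        # dict.get does not call __getitem__ or __missing__, so nothing is inserted here
        return super().get(key, 0)

    def __setitem__(self, key, value):
        self.calls.append(("set", key, value))
        super().__setitem__(key, value)


d = TracingDict()
res = d[5]   # lookup only: no key added
d[3] += 1    # lookup, then store
print(d.calls)  # [('get', 5), ('get', 3), ('set', 3, 1)]
print(dict(d))  # {3: 1}
```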

One use case is simple counting without the extra overhead of collections.Counter, but the behaviour would be useful more generally.


Counter example [pardon the pun]

This is the functionality I want:

from collections import Counter
c = Counter()
res = c[5]  # 0
print(c)  # Counter()

c[8] = 1  # key added successfully
c[3] += 1 # key added successfully
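Counter gets this behaviour from its `__missing__` method, which returns 0 without storing the key. For dict subclasses, `dict.__getitem__` calls `__missing__` when a key is absent. A quick check:

```python
from collections import Counter

c = Counter()
print(c[5])              # 0, supplied by Counter.__missing__
print(c.__missing__(5))  # 0, calling the hook directly
print(5 in c)            # False: the key was not stored

c[3] += 1                # lookup returns 0, then the store adds the key
print(c)                 # Counter({3: 1})
```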

But Counter is significantly slower than defaultdict(int): in my tests it is usually about 2x slower.

In addition, Counter only covers the defaultdict(int) case, whereas defaultdict accepts any factory, such as list or set.


Is there a way to implement the above behaviour efficiently, for instance by subclassing defaultdict?
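For comparison with the dict-based DictWithDefaults in the test code below, a defaultdict subclass can get the same behaviour by overriding `__missing__` so that it returns `default_factory()` without storing the result. This is a sketch: the class name `NonStoringDefaultDict` is hypothetical. In CPython the Python-level method only runs on misses, so hits keep the normal lookup path. It shares the mutable-value caveat of any non-storing lookup: `d[key].append(x)` on a missing key is lost.

```python
from collections import defaultdict


class NonStoringDefaultDict(defaultdict):
    """Hypothetical defaultdict variant whose missing-key lookups do not insert the key."""

    def __missing__(self, key):
        # defaultdict.__missing__ would store the value; here we only build and return it
        if self.default_factory is None:
            raise KeyError(key)
        return self.default_factory()


d = NonStoringDefaultDict(int)
res = d[5]     # 0, key not added
d[8] = 1       # explicit assignment adds the key
d[3] += 1      # lookup returns 0, then the store adds the key
print(dict(d))  # {8: 1, 3: 1}
```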


Benchmarking example

%timeit DwD(lst)           # 72 ms
%timeit dd(lst)            # 44 ms
%timeit counter_func(lst)  # 98 ms
%timeit af(lst)            # 72 ms

Test code:

import numpy as np
from collections import defaultdict, Counter

class DefaultDict(defaultdict):
    def get_and_forget(self, key):
        _sentinel = object()
        value = self.get(key, _sentinel)

        if value is _sentinel:
            return self.default_factory()
        return value

class DictWithDefaults(dict):
    __slots__ = ['_factory']  # avoid using extra memory

    def __init__(self, factory, *args, **kwargs):
        self._factory = factory
        super().__init__(*args, **kwargs)

    def __missing__(self, key):
        return self._factory()

lst = np.random.randint(0, 10, 100000)

def DwD(lst):
    d = DictWithDefaults(int)
    for i in lst:
        d[i] += 1
    return d

def dd(lst):
    d = defaultdict(int)
    for i in lst:
        d[i] += 1
    return d

def counter_func(lst):
    d = Counter()
    for i in lst:
        d[i] += 1
    return d

def af(lst):
    d = DefaultDict(int)
    for i in lst:
        d[i] += 1
    return d
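`%timeit` is an IPython magic. Outside IPython, a similar comparison can be run with the standard `timeit` module. The sketch below is an assumption-laden variant rather than a reproduction of the numbers above: it replaces the NumPy array with a plain list of ints (which also removes per-element NumPy scalar overhead) and uses a hypothetical `count` helper shared by all candidates. Timings will vary by machine and Python version.

```python
import random
import timeit
from collections import Counter, defaultdict


class DictWithDefaults(dict):
    """Same idea as in the test code: __missing__ returns a default without storing it."""
    __slots__ = ["_factory"]

    def __init__(self, factory, *args, **kwargs):
        self._factory = factory
        super().__init__(*args, **kwargs)

    def __missing__(self, key):
        return self._factory()


def count(make_dict, items):
    """Hypothetical helper: count occurrences using a mapping built by make_dict()."""
    d = make_dict()
    for i in items:
        d[i] += 1
    return d


# Plain Python ints in [0, 9], matching np.random.randint(0, 10, 100000)
lst = [random.randint(0, 9) for _ in range(100_000)]

candidates = {
    "DictWithDefaults": lambda: DictWithDefaults(int),
    "defaultdict": lambda: defaultdict(int),
    "Counter": Counter,
}

for name, make_dict in candidates.items():
    # Best of 3 repeats, each running the counting loop 5 times
    best = min(timeit.repeat(lambda: count(make_dict, lst), number=5, repeat=3))
    print(f"{name:>16}: {best / 5 * 1000:.1f} ms per loop")
```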

Note regarding the bounty comment:

@Aran-Fey's solution has been updated since the bounty was offered, so please disregard the bounty comment.


Solution

  • Rather than messing about with collections.defaultdict to make it do what we want, it seems easier to implement our own:

    class DefaultDict(dict):
        def __init__(self, default_factory, **kwargs):
            super().__init__(**kwargs)
    
            self.default_factory = default_factory
    
        def __getitem__(self, key):
            try:
                return super().__getitem__(key)
            except KeyError:
                return self.default_factory()
    

    This works the way you want:

    d = DefaultDict(int)
    
    res = d[5]
    d[8] = 1 
    d[3] += 1
    
    print(d)  # {8: 1, 3: 1}
    

    However, it can behave unexpectedly for mutable types:

    d = DefaultDict(list)
    d[5].append('foobar')
    
    print(d)  # output: {}
    

    This is probably why defaultdict stores the default value when a nonexistent key is accessed.


    Another option is to extend defaultdict and add a new method that looks up a value without remembering it:

    from collections import defaultdict
    
    class DefaultDict(defaultdict):
        def get_and_forget(self, key):
            return self.get(key, self.default_factory())
    

    Note that this get_and_forget method calls default_factory() on every call, whether or not the key already exists in the dict. If this is undesirable, you can implement it with a sentinel value instead:

    class DefaultDict(defaultdict):
        def get_and_forget(self, key):
            _sentinel = object()
            value = self.get(key, _sentinel)
    
            if value is _sentinel:
                return self.default_factory()
            return value
    

    This handles mutable types better, because you choose per lookup whether the default is stored: d[key] stores it, while d.get_and_forget(key) does not.
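A short usage sketch of the sentinel-based DefaultDict follows. It differs from the answer's version in one small, labeled way: the sentinel is created once at module level instead of on every call. It shows that `get_and_forget` leaves the dict untouched, while `d[key]` stores the new list so an in-place append persists.

```python
from collections import defaultdict

_SENTINEL = object()  # module-level marker, created once (a change from the answer's version)


class DefaultDict(defaultdict):
    def get_and_forget(self, key):
        """Return the stored value if present, else a fresh default, without storing it."""
        value = self.get(key, _SENTINEL)
        if value is _SENTINEL:
            return self.default_factory()
        return value


d = DefaultDict(list)

snapshot = d.get_and_forget(5)  # [] returned, key 5 not added
d[7].append("foobar")           # d[7] stores the new list, so the append persists

print(snapshot, dict(d))  # [] {7: ['foobar']}
```

For the dict-based DefaultDict from the first option, the built-in `dict.setdefault(key, default)` gives a similar choice: `d.setdefault(5, []).append('foobar')` stores the list before appending to it.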