Tags: python, dictionary, tree, minimax

large number of changing dictionaries for tree nodes in python


I am trying to implement a hierarchical tree structure in which every node has a slightly changed version of a dictionary. My problem is that, unlike similar structures in R, assigning a Python dictionary to a new name does not copy it; the name is just another reference to the same object. Therefore, any change made at one node affects the dictionary of every other node as well.
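The reference semantics can be seen in a minimal example:

```python
# Assignment never copies a dict; both names point at the same object.
a = {'count': 0}
b = a              # b is an alias for a, not a copy
b['count'] += 1
print(a['count'])  # -> 1: the change is visible through the original name
```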

Given this behaviour of dict, what would be the proper way to implement this in Python? It seems like a common task, so I feel I must be missing something, but I have been banging my head against the wall for hours now.

Background: I am trying to implement a Minimax approach for a turn-based, perfect-information, adversarial board game in Python, using a dictionary for the board state. I build a hierarchical tree of nodes and children from all possible moves, and so far I have been trying to modify the dictionary at each node. The reference semantics of dictionaries were unclear to me, so I was fighting strange results: the changes made at every single node were applied to all other nodes as well.

Example

#Create some original state (dict of dicts)
state_original = {'field1' : {'player':1, 'count':2}, 'field2' : {'player':2, 'count':4}}
print(state_original)

#Define object for the tree nodes
class Node(object):
  def __init__(self, depth, state, field=None):
    self.depth = depth
    self.state = state
    self.field = field
    self.subnodes = []
    if self.depth > 0:
      self.SpawnSubnodes()
  def SpawnSubnodes(self):
    for field in self.state:
      depth_new = self.depth - 1
      state_new = self.state  # no copy! every node shares the same dict
      state_new[field]['count'] += 1
      self.subnodes.append(Node(depth_new, state_new, field))

#Build tree
nodes = Node(3, state_original)
nodes.subnodes

#But: This is a mess now
print(state_original)

#This is a mess, too. Results are not meaningful :( 
print(nodes.subnodes[1].state)

It works with deepcopy, but that is too slow for my (larger) tree:

from copy import deepcopy

#Define object for the tree nodes
class Node(object):
  def __init__(self, depth, state, field=None):
    self.depth = depth
    self.state = state
    self.field = field
    self.subnodes = []
    if self.depth > 0:
      self.SpawnSubnodes()
  def SpawnSubnodes(self):
    for field in self.state:
      depth_new = self.depth - 1
      state_new = deepcopy(self.state)  # full copy: child states are independent
      state_new[field]['count'] += 1
      self.subnodes.append(Node(depth_new, state_new, field))
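With deepcopy in place, the original state survives any edits made to a child's copy (a quick check, reusing the state from the example above):

```python
from copy import deepcopy

state_original = {'field1': {'player': 1, 'count': 2},
                  'field2': {'player': 2, 'count': 4}}
snapshot = deepcopy(state_original)

child_state = deepcopy(state_original)
child_state['field1']['count'] += 1   # modify only the child's private copy

print(state_original == snapshot)     # -> True: the original is untouched
```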

Edit: I realised that copy does not work for me because my board state is a dictionary of dictionaries rather than a simple dictionary. I have updated my example code to reflect this accurately. A potential workaround would be to try to come up with a simpler representation of the board (probably splitting it into a "board shape" dict and a "board state" dict), but I feel there should be a more pythonic way to solve this problem?


Solution

  • Instead of copy.deepcopy, use copy.copy (a shallow copy): as long as the state dict is flat, you don't need a deep copy.

    import copy
    
    #Define object for the tree nodes
    class Node(object):
      def __init__(self, depth, state, field=None):
        self.depth = depth
        self.state = state
        self.field = field
        self.subnodes = []
        if self.depth > 0:
          self.SpawnSubnodes()
      def SpawnSubnodes(self):
        for field in self.state:
          depth_new = self.depth - 1
          state_new = copy.copy(self.state)  # shallow copy: fine for a flat dict
          state_new[field] += 1  # assumes a flat state, e.g. {'field1': 2, ...}
          self.subnodes.append(Node(depth_new, state_new, field))
    

    A shallow copy is much faster than deep copy. Here's a simple timing test:

    In [5]: %timeit copy.deepcopy(state_original)
    The slowest run took 6.96 times longer than the fastest. This could mean that an intermediate result is being cached 
    100000 loops, best of 3: 4.97 µs per loop
    
    In [6]: %timeit copy.copy(state_original)
    The slowest run took 8.84 times longer than the fastest. This could mean that an intermediate result is being cached 
    1000000 loops, best of 3: 709 ns per loop
    

    Note: the above solution works only when the dict in question is simple, i.e., it does not contain other dicts.
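A quick check illustrates why: copy.copy duplicates only the outer dict, so the inner dicts stay shared between the copy and the original:

```python
import copy

original = {'field1': {'player': 1, 'count': 2}}
shallow = copy.copy(original)        # new outer dict, same inner dicts
shallow['field1']['count'] += 1      # mutates the shared inner dict

print(original['field1']['count'])   # -> 3, not 2: the original changed too
```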

    If the dict to begin with contains other flat dicts (as the board state here does), shallow-copying its contents one level down can be faster than a deepcopy operation.

    def mycopy(d):
        # new outer dict whose values are shallow copies of the inner dicts
        return {k: copy.copy(v) for k, v in d.items()}
    

    A preliminary performance analysis of mycopy gives me roughly an order-of-magnitude improvement over copy.deepcopy.
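Unlike a plain copy.copy, mycopy keeps a child's state independent of its parent for the two-level board state used in the question:

```python
import copy

def mycopy(d):
    # one-level-deep copy: new outer dict, shallow-copied inner dicts
    return {k: copy.copy(v) for k, v in d.items()}

state = {'field1': {'player': 1, 'count': 2},
         'field2': {'player': 2, 'count': 4}}
child = mycopy(state)
child['field1']['count'] += 1     # modify only the child's copy

print(state['field1']['count'])   # -> 2: the parent state is untouched
```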