Tags: django, recursion, updates, prefetch, django-2.0

Access updated values in prefetched self recursion set


I have a model:

from django.db import models

class Vertex(models.Model):
  pmap = models.ForeignKey('PMap', on_delete=models.CASCADE)
  elevation = models.FloatField(default=0)
  flow_sink = models.ForeignKey(
      'Vertex',
      on_delete=models.CASCADE,
      related_name='upstream',
      null=True)

And another model with the following function:

class PMap(models.Model):
  def compute_flux(self, save=False):
    vertices = self.vertex_set.order_by('-elevation').prefetch_related('flow_sink')

    # Init flux
    for vert in vertices:
      vert.flux = 1.0 / len(vertices)

    # Add flux from current to the downstream node.
    for vert in vertices:
      if vert.id != vert.flow_sink.id:
        vert.flow_sink.flux = vert.flux 

The function compute_flux() is supposed to add the flux value of the currently visited vertex to its flow_sink (which is itself another vertex). It should do this recursively: when it reaches a vertex whose flux has already been increased by its upstream vertices, it should pass that accumulated value on to its own flow_sink.

Sadly, this doesn't work. All vertices end up with the initial flux = 1.0 / len(vertices). I think the reason is that we're updating the vertices in the prefetched set prefetch_related('flow_sink') rather than the vertices in the vertices set. Thus vert.flux in the last loop will never have any other value than the one set in the first (init) loop.

How can I fix this or work around the problem?


Solution

    The problem is that the Vertex objects loaded by prefetch_related are not the same Python objects as the ones in your vertices queryset. The two copies do compare equal (a v1 == v2 check succeeds), because Django's model equality only checks that both objects belong to the same concrete model and have the same primary key; it does not compare the other field values or attributes. They are still distinct objects (v1 is v2 is False), so a change made to one of them is not reflected in the other.
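
To see why the update gets lost, here is a framework-free sketch that mimics the equality rule described above. FakeModel is a hypothetical stand-in, not Django's real Model class; its __eq__ only compares the class and the primary key, which is enough to reproduce the effect of two separate loads of the same row:

```python
class FakeModel:
    """Hypothetical stand-in for a Django model instance (not Django's real class)."""

    def __init__(self, pk, **fields):
        self.pk = pk
        self.__dict__.update(fields)

    def __eq__(self, other):
        # Equal when same class and same primary key; other fields are ignored.
        if not isinstance(other, FakeModel):
            return NotImplemented
        if type(self) is not type(other):
            return False
        if self.pk is None:
            return self is other
        return self.pk == other.pk

    def __hash__(self):
        return hash(self.pk)


# Two separate loads of the "same" row: the queryset copy and the prefetched copy.
from_queryset = FakeModel(pk=7, elevation=3.0)
from_prefetch = FakeModel(pk=7, elevation=3.0)

from_prefetch.flux = 0.5  # update only the prefetched copy

print(from_queryset == from_prefetch)   # True: same class and pk
print(from_queryset is from_prefetch)   # False: two distinct objects
print(hasattr(from_queryset, "flux"))   # False: the update is not shared
```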

    We can solve this by keeping a dictionary that maps each primary key to its Vertex object and looking the sinks up there, like this:

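    class PMap(models.Model):

        # ...

        def compute_flux(self, save=False):
            # No prefetch needed: sinks are looked up in ver_dic instead.
            vertices = self.vertex_set.order_by('-elevation')
            # Evaluates the queryset once and maps each pk to its Vertex object.
            ver_dic = {v.pk: v for v in vertices}
            n = len(ver_dic)

            # Iterating the same queryset again reuses its result cache,
            # so these are the very objects stored in ver_dic.
            for vert in vertices:
                vert.flux = 1.0 / n

            # Add flux from current to the downstream node (highest first).
            for vert in vertices:
                if vert.flow_sink_id is not None and vert.flow_sink_id != vert.pk:
                    ver_dic[vert.flow_sink_id].flux += vert.flux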

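    You probably also meant to write += instead of =: with =, the sink's flux is overwritten by the flux of whichever upstream vertex is processed last instead of accumulating all contributions, and as long as every vertex still holds the same initial value, the assignment changes nothing at all.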
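
A tiny illustration of that difference, using plain dictionaries keyed by hypothetical vertex names instead of model instances: vertices a and b both drain into c, and every vertex starts with 1/3.

```python
# Two upstream vertices "a" and "b" both drain into the sink "c".
n = 3
flux_assign = {"a": 1 / n, "b": 1 / n, "c": 1 / n}
flux_accum = dict(flux_assign)

for upstream in ("a", "b"):
    flux_assign["c"] = flux_assign[upstream]  # overwrites: last writer wins
    flux_accum["c"] += flux_accum[upstream]   # accumulates every contribution

print(flux_assign["c"])  # still the initial 1/3
print(flux_accum["c"])   # approximately 1.0: all flux collected at c
```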

    We thus cache the Vertex objects in a dictionary and, instead of following the prefetched flow_sink relation, look each sink up in that dictionary, so every update lands on the same Python objects we iterate over.
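
The same idea can be sketched without Django. PlainVertex and compute_flux below are hypothetical plain-Python stand-ins for the model and method, and the sketch assumes, as the order_by('-elevation') in the answer implies, that every sink lies lower than the vertex draining into it:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class PlainVertex:
    """Hypothetical plain-Python stand-in for the Django Vertex model."""
    pk: int
    elevation: float
    flow_sink_id: Optional[int]  # None when the vertex has no sink
    flux: float = 0.0


def compute_flux(vertices: List[PlainVertex]) -> Dict[int, float]:
    # Highest vertex first, like order_by('-elevation'), so a vertex has received
    # all of its upstream flux before passing it on (assumes sinks lie lower).
    ordered = sorted(vertices, key=lambda v: v.elevation, reverse=True)
    # One shared object per primary key; every update goes through this dict.
    by_pk = {v.pk: v for v in ordered}

    for v in ordered:
        v.flux = 1.0 / len(ordered)

    for v in ordered:
        sink_id = v.flow_sink_id
        # Skip vertices without a sink and self-loops (the outlet).
        if sink_id is not None and sink_id != v.pk:
            by_pk[sink_id].flux += v.flux

    return {pk: v.flux for pk, v in by_pk.items()}


# A small chain 1 -> 2 -> 3, where vertex 3 drains into itself.
result = compute_flux([
    PlainVertex(pk=1, elevation=3.0, flow_sink_id=2),
    PlainVertex(pk=2, elevation=2.0, flow_sink_id=3),
    PlainVertex(pk=3, elevation=1.0, flow_sink_id=3),
])
print(result)
```

Here vertex 2 ends up with roughly 2/3 and the outlet, vertex 3, with roughly 1.0, because each vertex's accumulated flux is passed on in a single downhill pass.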

    Here we assume that every flow_sink is itself part of the vertices queryset. If that is not the case, we can still make this work by adding the missing flow_sink vertices to the dictionary as we encounter them:

    class PMap(models.Model):

        # ...

        def compute_flux(self, save=False):
            vertices = self.vertex_set.order_by('-elevation').prefetch_related('flow_sink')
            ver_dic = {v.pk: v for v in vertices}
            n = len(ver_dic)  # count only the vertices of this PMap

            for vert in vertices:
                vert.flux = 1.0 / n

            # if some flow_sink vertices are not part of the PMap, add them too;
            # they start with no flux of their own (an assumption) so += works
            for vert in vertices:
                if vert.flow_sink_id is not None and vert.flow_sink_id not in ver_dic:
                    vert.flow_sink.flux = 0.0
                    ver_dic[vert.flow_sink_id] = vert.flow_sink

            # Add flux from current to the downstream node.
            for vert in vertices:
                if vert.flow_sink_id is not None and vert.flow_sink_id != vert.pk:
                    ver_dic[vert.flow_sink_id].flux += vert.flux

    Note, however, that flux is only an attribute on these in-memory objects: unless you store it in a model field and save the objects, setting .flux has no lasting effect, and if you later fetch the vertices again, they will not carry the computed values.