Search code examples
python graph depth-first-search

How to handle count in recursion


import unittest

def riverSizes(test_input):
   """Scan the matrix for cells equal to 1 and measure each river found.

   For every cell that holds 1, traverse_river is started from that cell
   and whatever it returns is appended to ``l``; ``river_sizes`` counts how
   many times that happened. Both values are printed at the end.

   NOTE(review): there is no return statement, so callers receive None.
   """
   l=[]
   # Incremented once per append to ``l``, so it always equals len(l).
   river_sizes = 0
   for i in range(len(test_input)):
       for j in range(len(test_input[i])):
           if test_input[i][j] == 1:
               # The traversal overwrites visited 1s with 2, so cells it has
               # reached do not trigger another traversal in this scan.
               count = traverse_river(i, j, len(test_input), len(test_input[0]), test_input, 0)
               l.append(count)
               river_sizes += 1
   # Results are only printed, not returned.
   print(river_sizes,l)

def traverse_river(x, y, max_row, max_col, matrix, count):
    """Recursively visit cells equal to 1 starting from (x, y), marking them 2.

    ``count`` is passed down to each recursive call incremented by one.

    NOTE(review): only the out-of-bounds / non-1 branch returns a value; when
    the cell is a 1 the function falls off the end and returns None, and the
    results of the recursive calls are discarded.
    """
    # Stop when outside the grid or on a cell that is not an unvisited 1.
    if x < 0 or x >= max_row or y < 0 or y >= max_col or matrix[x][y] != 1:
       return count
    # Mark the cell so it is not visited again.
    matrix[x][y] = 2
    # Neighbours visited: (x - 1, y), (x + 1, y) and (x, y + 1).
    # NOTE(review): there is no call for (x, y - 1).
    traverse_river(x - 1, y, max_row, max_col, matrix, count + 1)
    traverse_river(x + 1, y, max_row, max_col, matrix, count + 1)
    traverse_river(x, y + 1, max_row, max_col, matrix, count + 1)
    
class TestProgram(unittest.TestCase):
    """Checks riverSizes against a 5x5 sample matrix."""

    def test_case_1(self):
       """Expect the sorted result of riverSizes to be [1, 2, 2, 2, 5]."""
       test_input = [[1, 0, 0, 1, 0], [1, 0, 1, 0, 0], [0, 0, 1, 0, 1], [1, 0, 1, 0, 1], [1, 0, 1, 1, 0]]
       expected = [1, 2, 2, 2, 5]
       # NOTE(review): riverSizes is called twice on the same list. The
       # traversal it runs rewrites visited 1s to 2 in place, so the second
       # call below does not see the original matrix.
       riverSizes(test_input)
       self.assertEqual(sorted(riverSizes(test_input)), expected)

I am trying to run DFS on the matrix and keep track of all the connected components of "1" cells, along with the size (the count of connected ones) of each component. How do I handle the count in recursion?


Solution

  • A few issues with the code:

    • The function riverSizes does not return anything, so applying sorted on it will not help either.

    • The recursion should also consider y - 1, as it may be that there is a U-shape, like this pattern:

        1 0 1
        1 1 1
      

      ...and you don't want to miss out on the 1 in the top-right corner.

    • The function traverse_river does not always return a count -- only if the first if condition is true.

    Moreover, you don't need:

    • to pass count as an argument. Just add 1 to the counts returned by the recursive calls, so that the value accumulates on the way back up.
    • to maintain river_sizes, as that is duplicate information: it is actually equal to len(l).

    Here is how it could work:

    def riverSizes(test_input):
        """Return a list with the size of each river found in ``test_input``.

        The matrix is scanned row by row. Whenever a cell still holding 1 is
        met, traverse_river is started there and its result is recorded, so
        sizes appear in the order the rivers are first encountered.

        The traversal overwrites visited 1s with 2, so ``test_input`` is
        modified in place.
        """
        sizes = []
        for i, row in enumerate(test_input):
            for j, cell in enumerate(row):
                if cell != 1:
                    continue
                sizes.append(
                    traverse_river(i, j, len(test_input), len(test_input[0]), test_input)
                )
        return sizes
    
    def traverse_river(x, y, max_row, max_col, matrix):
        """Return the number of connected 1-cells reachable from (x, y).

        Cells are connected through their up, down, right and left
        neighbours. Every counted cell is overwritten with 2 so that it is
        never counted twice. Returns 0 when (x, y) lies outside the
        ``max_row`` x ``max_col`` grid or does not hold 1.
        """
        inside = 0 <= x < max_row and 0 <= y < max_col
        if not inside or matrix[x][y] != 1:
            return 0

        # Mark before recursing so neighbours do not come back to this cell.
        matrix[x][y] = 2

        # This cell counts as 1; each neighbour adds the size of its branch.
        size = 1
        for dx, dy in ((-1, 0), (1, 0), (0, 1), (0, -1)):
            size += traverse_river(x + dx, y + dy, max_row, max_col, matrix)
        return size
        
    # Sample grid from the question, one row per line.
    test_input = [
        [1, 0, 0, 1, 0],
        [1, 0, 1, 0, 0],
        [0, 0, 1, 0, 1],
        [1, 0, 1, 0, 1],
        [1, 0, 1, 1, 0],
    ]
    # Sizes are listed in the order the rivers are found: [2, 1, 5, 2, 2]
    print(riverSizes(test_input))