Search code examples
Tags: python, computer-science, greedy

How do I get the solution path from a greedy algorithm?


I have a greedy algorithm for the job scheduling problem, but I also want to return which projects were chosen to reach this maximum value. How can I do that?

from dataclasses import dataclass
from datetime import date


@dataclass()
class InvestmentProject:
    """A candidate project: the profit it yields and the dates it spans."""

    # Profit earned when this project is selected.
    profit: int
    # Date on which the project starts.
    begin_date: date
    # Date on which the project ends; a project beginning on this same date
    # is treated as non-overlapping by get_max_sequence (it compares with `>`).
    end_date: date


def get_max_sequence(arr, i=0):
    """Return the maximum total profit of non-overlapping projects in ``arr[i:]``.

    Each element of ``arr`` must expose ``profit``, ``begin_date`` and
    ``end_date``. Project ``k`` is compatible with a later project ``j`` when
    ``arr[k].end_date <= arr[j].begin_date`` (a project may start on the day
    the previous one ends).

    The search for the next compatible project stops at the first compatible
    index, which is only correct when ``arr`` is sorted by ``begin_date``.

    Args:
        arr: Sequence of projects, expected to be sorted by ``begin_date``.
        i: Index of the first project that may be chosen.

    Returns:
        The best achievable total profit; 0 when ``i == len(arr)``.

    Raises:
        IndexError: if ``i > len(arr)``.
    """
    n = len(arr)
    # best[k] is the best profit obtainable from arr[k:]; best[n] == 0 means
    # "no projects left". Filling the table from the end replaces the plain
    # take/skip recursion, whose overlapping subcalls made it O(2**n) and
    # whose depth grew with len(arr).
    best = [0] * (n + 1)
    for k in range(n - 1, i - 1, -1):
        # First project after k that does not overlap project k.
        j = k + 1
        while j < n and arr[k].end_date > arr[j].begin_date:
            j += 1
        take = arr[k].profit + best[j]
        skip = best[k + 1]
        best[k] = max(take, skip)
    return best[i]


def main():
    """Print the maximum total profit for a small hard-coded set of projects."""
    projects = [
        InvestmentProject(30, date(2022, 10, 10), date(2022, 10, 14)),
        InvestmentProject(15, date(2022, 10, 15), date(2022, 10, 16)),
        InvestmentProject(25, date(2022, 10, 12), date(2022, 10, 15)),
        InvestmentProject(10, date(2022, 10, 20), date(2022, 10, 26)),
    ]
    # get_max_sequence's scan for the next compatible project assumes the
    # projects are ordered by start date.
    projects_by_start = sorted(projects, key=lambda project: project.begin_date)
    best_profit = get_max_sequence(projects_by_start)
    print(best_profit)


Solution

  • You can return both the value and the list of chosen indexes.

    First, make the base case return an empty list alongside 0:

    if i == len(arr):
       return 0, []
    

    Next, unpack both the value and the index list from each recursive call before doing the calculations:

        val, indexes = get_max_sequence(arr, j)
    
        one = arr[i].profit + val
        
        two, other_indexes = get_max_sequence(arr, i+1)
    

    Finally, replace max() with an explicit comparison so you know which branch won and can return its index list:

        if one > two:
            return one, indexes + [i]
        else:
            return two, other_indexes
    

    Minimal working code (the returned indexes refer to positions in the sorted list and are collected last-to-first):

    from dataclasses import dataclass
    from datetime import date
    
    
    @dataclass()
    class InvestmentProject:
        """A candidate project: the profit it yields and the dates it spans."""

        # Profit earned when this project is selected.
        profit: int
        # Date on which the project starts.
        begin_date: date
        # Date on which the project ends; a project beginning on this same date
        # is treated as non-overlapping by get_max_sequence (it compares with `>`).
        end_date: date
    
    
    def get_max_sequence(arr, i=0):
        """Return the best total profit in ``arr[i:]`` and the projects achieving it.

        Each element of ``arr`` must expose ``profit``, ``begin_date`` and
        ``end_date``. Project ``k`` is compatible with a later project ``j``
        when ``arr[k].end_date <= arr[j].begin_date``. The scan for the next
        compatible project stops at the first match, so ``arr`` is expected to
        be sorted by ``begin_date``.

        Args:
            arr: Sequence of projects, expected to be sorted by ``begin_date``.
            i: Index of the first project that may be chosen.

        Returns:
            A ``(total_profit, indexes)`` tuple. ``indexes`` are positions in
            ``arr`` of the chosen projects, listed last-to-first (descending).
            When taking and skipping a project tie, the project is skipped.

        Raises:
            IndexError: if ``i > len(arr)``.
        """
        n = len(arr)
        # best[k]: best profit from arr[k:]. take[k]: whether project k is part
        # of that best plan. nxt[k]: first project after k compatible with k.
        # Filling these bottom-up replaces the exponential take/skip recursion.
        best = [0] * (n + 1)
        take = [False] * (n + 1)
        nxt = [n] * (n + 1)
        for k in range(n - 1, i - 1, -1):
            j = k + 1
            while j < n and arr[k].end_date > arr[j].begin_date:
                j += 1
            one = arr[k].profit + best[j]
            two = best[k + 1]
            nxt[k] = j
            # Strict '>' so that ties prefer skipping project k.
            take[k] = one > two
            best[k] = one if take[k] else two

        # Replay the recorded decisions from i to recover the chosen projects.
        chosen = []
        k = i
        while k < n:
            if take[k]:
                chosen.append(k)
                k = nxt[k]
            else:
                k += 1
        # Report them last-to-first, the order the recursive version produced.
        chosen.reverse()
        return best[i], chosen
    
    
    def main():
        """Print the sample projects in start-date order, then the best plan.

        The last line printed is the ``(total_profit, indexes)`` tuple returned
        by get_max_sequence; the indexes refer to the sorted list printed above.
        """
        arr = [
            InvestmentProject(30, date(2022, 10, 10), date(2022, 10, 14)),
            InvestmentProject(15, date(2022, 10, 15), date(2022, 10, 16)),
            InvestmentProject(25, date(2022, 10, 12), date(2022, 10, 15)),
            InvestmentProject(10, date(2022, 10, 20), date(2022, 10, 26)),
        ]

        # get_max_sequence expects the projects ordered by start date.
        arr = sorted(arr, key=lambda x: x.begin_date)
        for item in arr:
            print(item)
        print(get_max_sequence(arr))
        
    # Run the demo only when executed as a script, not on import.
    if __name__ == "__main__":
        main()