I have greedy algoithm with job scheduling problem, but I want to return which one projects were chosen to get this max value, how can I do that?
from dataclasses import dataclass
from datetime import date
@dataclass
class InvestmentProject:
profit: int
begin_date: date
end_date: date
def get_max_sequence(arr, i=0):
if i == len(arr):
return 0
j = i + 1
while j < len(arr) and arr[i].end_date > arr[j].begin_date:
j += 1
one = arr[i].profit + get_max_sequence(arr, j)
two = get_max_sequence(arr, i+1)
return max(one, two)
def main():
arr = [
InvestmentProject(30, date(2022, 10, 10), date(2022, 10, 14)),
InvestmentProject(15, date(2022, 10, 15), date(2022, 10, 16)),
InvestmentProject(25, date(2022, 10, 12), date(2022, 10, 15)),
InvestmentProject(10, date(2022, 10, 20), date(2022, 10, 26)),
]
print(get_max_sequence(sorted(arr, key=lambda x: x.begin_date)))
You could always return value and list of indexes.
First
if i == len(arr):
return 0, []
Next you would have to always get value and list before calculations
val, indexes = get_max_sequence(arr, j)
one = arr[i].profit + val
two, other_indexes = get_max_sequence(arr, i+1)
And you would have to manually check max()
if one > two:
return one, indexes + [i]
else:
return two, other_indexes
Minimal working code:
from dataclasses import dataclass
from datetime import date
@dataclass
class InvestmentProject:
profit: int
begin_date: date
end_date: date
def get_max_sequence(arr, i=0):
if i == len(arr):
return 0, []
j = i + 1
while j < len(arr) and arr[i].end_date > arr[j].begin_date:
j += 1
val, indexes = get_max_sequence(arr, j)
one = arr[i].profit + val
two, other_indexes = get_max_sequence(arr, i+1)
if one > two:
print('one:', indexes+[i])
return one, indexes + [i]
else:
print('two:', other_indexes)
return two, other_indexes
def main():
arr = [
InvestmentProject(30, date(2022, 10, 10), date(2022, 10, 14)),
InvestmentProject(15, date(2022, 10, 15), date(2022, 10, 16)),
InvestmentProject(25, date(2022, 10, 12), date(2022, 10, 15)),
InvestmentProject(10, date(2022, 10, 20), date(2022, 10, 26)),
]
arr = list(sorted(arr, key=lambda x: x.begin_date))
for item in arr:
print(item)
print(get_max_sequence(arr))
main()