Tags: algorithm, sorting, time-complexity, dynamic-programming

Maximum profit from buying/selling stocks when the position is limited to at most k shares in absolute value


I have a question about a classic interview problem, but with a twist:

You are given the stock prices for n days, a1, a2, ..., an. You are also given an integer 1 <= k <= n. Each day, you can do one of three things:

  1. do nothing
  2. sell one unit of stock at the price of that day. (you can sell even if you don't have any stock in your portfolio as long as you are no more than k stocks in debt)
  3. buy one unit of stock at the price of that day.

But at any time you can have at most k stocks in your possession, or be at most k stocks in debt. For example, if you have k stocks, you cannot buy any more without selling some first. Likewise, if you are already k stocks in debt, you cannot sell any more without buying some first. Find the maximum profit you can make.

You want to have net zero stocks after the nth day. Please do this more efficiently than O(nk) or O(n^2), for instance in O(n log k).

I can do this in O(nk) using dynamic programming on the day number and the net number of stocks in my possession. But how do I do this more efficiently?

Any help would be greatly appreciated.

Thanks!


Solution

  • You can do this in O(n log k) time. There are different ways to think about the solution, but I think it's simplest to think of it as an optimization of the dynamic programming approach you already know about.

    Let's say you want to iterate through the days, keeping track of the best cash position for each possible number of shares held. We'll say c[i] is the most cash you can have while holding i shares. To process a new day with price p, for each i you compute c'[i] = max( c[i-1] - p, c[i+1] + p, c[i] ), finding the best way to end up with i shares by buying one, selling one, or holding, respectively.
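
    As a minimal sketch of that one-day update (the name `step`, the index offset, and the use of -inf for holdings that cannot be reached yet are illustrative assumptions), holdings -k..k can be stored at indices 0..2k:

```python
NEG_INF = float("-inf")

def step(cash, price):
    """Apply one day of the recurrence to a list of cash positions.

    cash[j] is the best cash when holding j - k shares; -inf marks
    holdings that cannot be reached yet.
    """
    n = len(cash)
    new = []
    for j in range(n):
        hold = cash[j]
        buy = cash[j - 1] - price if j > 0 else NEG_INF       # came from one share fewer
        sell = cash[j + 1] + price if j + 1 < n else NEG_INF  # came from one share more
        new.append(max(hold, buy, sell))
    return new

k = 3
cash = [NEG_INF] * (2 * k + 1)
cash[k] = 0  # start with zero shares and zero cash
for price in [1, 5, 2]:
    cash = step(cash, price)
print(cash)  # [8, 7, 6, 4, 2, -3, -8] for holdings -3..3
```

    Each call touches all 2k+1 positions, which is where the O(nk) cost of the plain DP comes from.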

    We can observe/prove some interesting properties of these cash positions:

    1. As i increases, c[i] decreases monotonically -- it will always cost you more money to end up with more shares.
    2. As i increases, the difference c[i] - c[i+1] never decreases. Intuitively, this is because each difference is the price of some share that you could have bought or sold, and it's always better to buy the cheapest shares first or sell the most expensive shares first. Formally, we could prove it inductively. Let's call this difference p[i].
    3. When the best c[i] comes from buying a share, the same is true for every c[j] with j > i. This is easy to prove from (2).
    4. When the best c[i] comes from selling a share, the same is true for every c[j] with j < i. This is similarly easy to prove from (2).
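
    Properties (1) and (2) can be spot-checked on a column of cash positions. The helper below is only an illustrative sketch (the name `check_properties` is made up here), and it assumes strictly positive prices; the sample column is the last day of the worked example in this answer (k = 3, prices 1 5 2 3 9 2 4):

```python
def check_properties(cash):
    """cash lists the reachable positions in order of increasing shares held."""
    diffs = [cash[j] - cash[j + 1] for j in range(len(cash) - 1)]
    decreasing = all(d > 0 for d in diffs)                     # property (1)
    widening = all(a <= b for a, b in zip(diffs, diffs[1:]))   # property (2)
    return decreasing, widening, diffs

# Cash positions for holdings -3..3 after the last day of the example
final = [20, 18, 16, 13, 10, 6, 2]
print(check_properties(final))  # (True, True, [2, 2, 3, 3, 4, 4])
```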

    Because of property (2), we don't need to maintain a mapping from each i to c[i]. We only need to maintain the sorted multiset of differences and the best cash position for the smallest possible number of shares held.
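
    To see why that compact representation loses nothing, here is a small sketch (the function name `rebuild_cash` is hypothetical) that recovers every cash position from the lowest one plus the differences, using the last day of the worked example in this answer:

```python
from itertools import accumulate

def rebuild_cash(min_held_cash, diffs):
    """Recover every cash position from the lowest holding's cash plus the gaps."""
    # Moving up by one share costs the next-smallest difference.
    return list(accumulate(sorted(diffs), lambda c, d: c - d, initial=min_held_cash))

# Holding -3 shares with 20 cash, differences 2 2 3 3 4 4
print(rebuild_cash(20, [2, 2, 3, 3, 4, 4]))
# [20, 18, 16, 13, 10, 6, 2] for holdings -3..3
```

    The `initial` argument of `accumulate` needs Python 3.8 or newer.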

    Because of properties (3) and (4), the number of changes we need to make to the set of differences is bounded by a constant. Each day you just add 2 copies of p to the set of differences, and then, if you have more than 2k differences, remove the highest and lowest.
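
    The daily update of the differences can be traced with a short sketch. It tracks only the differences, not the cash for the smallest holding, and it uses a plain sorted list with `bisect.insort` purely for illustration, so each insert is O(k) rather than O(log k). The name `diff_history` is made up for this example:

```python
from bisect import insort

def diff_history(k, prices):
    """Return the sorted multiset of differences after each day."""
    diffs = []      # kept sorted ascending
    history = []
    for p in prices:
        insort(diffs, p)
        insort(diffs, p)
        if len(diffs) > 2 * k:
            diffs.pop()    # drop the highest
            diffs.pop(0)   # drop the lowest
        history.append(list(diffs))
    return history

for day in diff_history(3, [1, 5, 2, 3, 9, 2, 4]):
    print(day)
# last line printed: [2, 2, 3, 3, 4, 4]
```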

    If you maintain the differences in a balanced BST or similar structure, then each addition and removal takes O(log k) time, leading to an O(n log k) algorithm overall.


    Example:

    Normal DP table
    k = 3
    Cash positions by day and number of shares held at end of day
    PRICE    1   5   2   3   9   2   4
    -----------------------------------
    HELD: 3         -8  -6  -6   1   2 
          2     -6  -3  -1   3   6   6
          1 -1  -1   2   2   8   9  10
          0  0   4   4   5  11  12  13
         -1  1   5   6   7  14  14  16
         -2      6   7   9  16  16  18
         -3          8  10  18  18  20
    
    DIFFS            5   5   9   5   4
                 5   5   3   5   3   4
             1   5   2   3   3   3   3
             1   1   2   2   3   2   3
                 1   1   2   2   2   2
                     1   1   2   2   2
    
    

    Here's an implementation in Python, along with a slow DP implementation and a test to make sure they match.

    This makes use of the sortedcontainers module, which you might have to pip install:

    from sortedcontainers import SortedList
    from random import randint
    
    def maxProfit(k, dailyPrices):
        minheld=0
        minheldcash=0
        diffs = SortedList()
        for price in dailyPrices:
            diffs.add(price)
            diffs.add(price)
            if minheld > -k:
                # can only get to new minheld by selling short
                minheld -= 1
                minheldcash += price
            else:
                # minheld doesn't change.  Can sell short or hold
                diffs.pop() # highest
                lo = diffs.pop(0) # lowest
                if lo < price:
                    # better to sell
                    minheldcash = minheldcash - lo + price
                # else better to hold
    
        # calculate profit for 0 held at end
        for _ in range(minheld,0):
            minheldcash -= diffs.pop(0)
        
        return minheldcash
    
    
    def maxProfitSlow(k, dailyPrices):
        cash = [0]*(2*k+1)  # cash for holding, offset by k
        minheld = 0
        maxheld = 0
        for price in dailyPrices:
            if maxheld < k:
                maxheld += 1
                minheld -= 1
                cash[maxheld+k] = cash[maxheld+k-1] - price
                cash[minheld+k] = cash[minheld+k+1] + price
            newcash = [0]*len(cash)
            newcash[minheld+k] = max(cash[minheld+k], cash[minheld+k+1] + price)
            newcash[maxheld+k] = max(cash[maxheld+k], cash[maxheld+k-1] - price)
            for i in range(minheld+k+1, maxheld+k):
                newcash[i] = max(cash[i-1]-price, cash[i], cash[i+1]+price)
            cash = newcash
        return cash[k]
    
    
    for _ in range(1000):
        k = randint(1,100)
        prices = [randint(1,100) for _ in range(randint(1,100))]
        if (maxProfit(k,prices) != maxProfitSlow(k,prices)):
            raise AssertionError("Matt is wrong")  # never happens ;-)
    
    print("success")