I have a classic interview problem, but with a twist:
You are given the stock prices for $n$ days, $a_1, a_2, \ldots, a_n$. You are also given an integer $1 \le k \le n$. Each day, you can do one of three things:
- do nothing
- sell one unit of stock at the price of that day (you can sell even if you don't have any stock in your portfolio, as long as you end up no more than $k$ stocks in debt)
- buy one unit of stock at the price of that day.
But at any time you can hold at most $k$ stocks, and you can be at most $k$ stocks in debt. For example, if you have $k$ stocks, you cannot buy any more without selling some. Likewise, if you are already $k$ stocks in debt, you cannot sell any more without buying some. Find the maximum profit you can make.
You want to have net zero stocks after the $n$-th day. Please do this more efficiently than $O(nk)$ or $O(n^2)$, for instance in $O(n \log k)$.
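A brute-force reference for tiny inputs makes these rules concrete and gives something to test faster solutions against. The sketch below is not part of the original question, and the name `max_profit_brute_force` is hypothetical: it tries all $3^n$ buy/sell/hold sequences, discards any that go more than $k$ long or $k$ short, and keeps the best one that ends with zero stocks.

```python
from itertools import product


def max_profit_brute_force(k, prices):
    """Best profit over all 3**n buy/sell/hold sequences (tiny n only)."""
    best = 0  # doing nothing every day is always allowed and earns 0
    for actions in product((-1, 0, 1), repeat=len(prices)):  # -1 sell, 0 hold, +1 buy
        held, cash = 0, 0
        for action, price in zip(actions, prices):
            held += action
            if abs(held) > k:  # more than k long or more than k short
                break
            cash -= action * price  # buying costs the price, selling earns it
        else:  # no break: every day stayed within the limits
            if held == 0:  # must end with net zero stocks
                best = max(best, cash)
    return best


print(max_profit_brute_force(3, [1, 5, 2, 3, 9, 2, 4]))  # 13
```

For $k = 3$ and prices 1 5 2 3 9 2 4 it returns 13, which agrees with the DP table in the answer.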
I can do this in $O(nk)$ using dynamic programming over the day number and the net number of stocks in my possession. But how do I do this more efficiently?
Any help would be greatly appreciated.
Thanks!
You can do this in O(n log k) time. There are different ways to think about the solution, but I think it's simplest to think of it as an optimization of the dynamic programming approach you already know about.
Let's say you want to iterate through the days, keeping track of the best cash position for each possible number of shares held. We'll say $c_i$ is the most cash you can have while holding $i$ shares. To calculate a new day with price $p$, for each $i$ you do $c'_i = \max(c_{i-1} - p,\ c_{i+1} + p,\ c_i)$, finding the best way to get $i$ shares by buying one, selling one, or holding, respectively.
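As a minimal sketch of that recurrence, one day of the DP can be written with a dict keyed by shares held (the helper name `dp_step` is hypothetical). Folding it over the example prices used below reproduces the final column of the example table.

```python
def dp_step(cash, price, k):
    """Advance the O(nk) DP by one day.

    `cash` maps shares held i -> c_i, the most cash possible while holding i shares.
    Returns c'_i = max(c_{i-1} - p, c_{i+1} + p, c_i) for every reachable |i| <= k.
    """
    new_cash = {}
    for i in range(max(min(cash) - 1, -k), min(max(cash) + 1, k) + 1):
        options = []
        if i - 1 in cash:
            options.append(cash[i - 1] - price)  # buy one share
        if i + 1 in cash:
            options.append(cash[i + 1] + price)  # sell one share
        if i in cash:
            options.append(cash[i])  # do nothing
        new_cash[i] = max(options)
    return new_cash


cash = {0: 0}  # before day 1: zero shares, zero cash
for price in [1, 5, 2, 3, 9, 2, 4]:
    cash = dp_step(cash, price, 3)
print(cash[0])  # 13, the best profit ending with zero shares
```

Each call touches $O(k)$ entries, which is where the $O(nk)$ total comes from.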
We can observe/prove some interesting properties of these cash positions:
Because of property (2), we don't need to maintain a mapping from each $i$ to a $c$. We only need to maintain the set of differences $c_{i-1} - c_i$ between adjacent cash positions, and the best cash position for the smallest possible number of shares held.
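To see why the differences plus one anchor are enough, here is a small sketch (the helper name `cash_from_diffs` is hypothetical) that rebuilds every $c_i$ from the lowest holding, its cash, and the differences. It assumes, as the example table shows, that $c_{i-1} - c_i$ grows with $i$, so moving up one share subtracts the next-smallest difference. With the last day of the example (holding $-3$ shares gives 20, differences 2, 2, 3, 3, 4, 4) it recovers the whole final column.

```python
def cash_from_diffs(min_held, min_cash, diffs):
    """Rebuild {shares held: c_i} from c at the lowest holding and the differences.

    Assumes the differences c_{i-1} - c_i grow with i, so they are consumed
    smallest first while walking up from min_held.
    """
    cash = {min_held: min_cash}
    for step, d in enumerate(sorted(diffs), start=1):
        cash[min_held + step] = cash[min_held + step - 1] - d
    return cash


print(cash_from_diffs(-3, 20, [4, 4, 3, 3, 2, 2])[0])  # 13
```

This is the same walk the final loop of `maxProfit` performs, stopping at zero shares.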
Because of properties (3) and (4), the number of changes we need to make to the set of differences is bounded by a constant. Each day you just add 2 copies of $p$ to the set of differences (really a multiset, since values repeat) and then, if you have more than $2k$ differences, remove the highest and the lowest.
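This update rule can be checked numerically against the plain DP. The sketch below (all helper names are hypothetical) recomputes the DP day by day, takes the differences $c_{i-1} - c_i$ between adjacent holdings, and asserts that they always equal the previous day's differences plus two copies of $p$, with the lowest and highest dropped once there are more than $2k$ of them.

```python
NEG_INF = float("-inf")


def cash_by_day(k, prices):
    """Yield (price, {shares held: best cash}) after each day of the O(nk) DP."""
    cash = {0: 0}
    for p in prices:
        lo, hi = max(min(cash) - 1, -k), min(max(cash) + 1, k)
        cash = {
            i: max(
                cash.get(i - 1, NEG_INF) - p,  # buy one
                cash.get(i + 1, NEG_INF) + p,  # sell one
                cash.get(i, NEG_INF),  # hold
            )
            for i in range(lo, hi + 1)
        }
        yield p, cash


def sorted_diffs(cash):
    """Differences c_{i-1} - c_i between adjacent holdings, sorted ascending."""
    return sorted(cash[i - 1] - cash[i] for i in range(min(cash) + 1, max(cash) + 1))


def check_difference_rule(k, prices):
    """Assert the 'add two copies of p, trim to 2k' rule holds on every day."""
    expected = []
    for p, cash in cash_by_day(k, prices):
        expected = sorted(expected + [p, p])
        if len(expected) > 2 * k:
            expected = expected[1:-1]  # remove the lowest and the highest
        assert sorted_diffs(cash) == expected, (p, sorted_diffs(cash), expected)
    return expected


print(check_difference_rule(3, [1, 5, 2, 3, 9, 2, 4]))  # [2, 2, 3, 3, 4, 4]
```

The returned list matches the last DIFFS column of the example.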
If you maintain the set of differences in a balanced BST or similar structure, then these additions and removals take $O(\log k)$ time each, leading to an $O(n \log k)$ algorithm overall.
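Example (each DIFFS row holds the differences $c_{i-1} - c_i$ between two adjacent HELD rows of the same column, e.g. $0 - (-1) = 1$ on the first day):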
```
Normal DP table
K = 3
Cash positions by day and num held at EOD
PRICE         1    5    2    3    9    2    4
---------------------------------------------
HELD:   3              -8   -6   -6    1    2
        2         -6   -3   -1    3    6    6
        1    -1   -1    2    2    8    9   10
        0     0    4    4    5   11   12   13
       -1     1    5    6    7   14   14   16
       -2          6    7    9   16   16   18
       -3               8   10   18   18   20
DIFFS                   5    5    9    5    4
                   5    5    3    5    3    4
              1    5    2    3    3    3    3
              1    1    2    2    3    2    3
                   1    1    2    2    2    2
                        1    1    2    2    2
```
Here's an implementation in Python, along with a slow DP implementation and a test to make sure they match.
This makes use of the `sortedcontainers` module, which you might have to `pip install`:
```python
from sortedcontainers import SortedList
from random import randint


# O(n log k): track the cash at the lowest holding plus the sorted differences
def maxProfit(k, dailyPrices):
    minheld = 0
    minheldcash = 0
    diffs = SortedList()
    for price in dailyPrices:
        diffs.add(price)
        diffs.add(price)
        if minheld > -k:
            # can only get to new minheld by selling short
            minheld -= 1
            minheldcash += price
        else:
            # minheld doesn't change. Can sell short or hold
            diffs.pop()  # highest
            lo = diffs.pop(0)  # lowest
            if lo < price:
                # better to sell
                minheldcash = minheldcash - lo + price
            # else better to hold
    # calculate profit for 0 held at end
    for _ in range(minheld, 0):
        minheldcash -= diffs.pop(0)
    return minheldcash


# O(nk) reference DP, used only to test maxProfit
def maxProfitSlow(k, dailyPrices):
    cash = [0] * (2 * k + 1)  # cash for holding, offset by k
    minheld = 0
    maxheld = 0
    for price in dailyPrices:
        if maxheld < k:
            # the reachable range grows by one share in each direction
            maxheld += 1
            minheld -= 1
            cash[maxheld + k] = cash[maxheld + k - 1] - price
            cash[minheld + k] = cash[minheld + k + 1] + price
        newcash = [0] * len(cash)
        newcash[minheld + k] = max(cash[minheld + k], cash[minheld + k + 1] + price)
        newcash[maxheld + k] = max(cash[maxheld + k], cash[maxheld + k - 1] - price)
        for i in range(minheld + k + 1, maxheld + k):
            newcash[i] = max(cash[i - 1] - price, cash[i], cash[i + 1] + price)
        cash = newcash
    return cash[k]


for _ in range(1000):
    k = randint(1, 100)
    prices = [randint(1, 100) for _ in range(randint(1, 100))]
    if maxProfit(k, prices) != maxProfitSlow(k, prices):
        raise AssertionError("Matt is wrong")  # never happens ;-)
print("success")
```
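If installing `sortedcontainers` is not an option, the same loop can run on a plain Python list kept sorted with the standard `bisect` module. This is a sketch under that substitution (the name `max_profit_bisect` is hypothetical): `insort` and `pop(0)` are $O(k)$ list operations, so it gives up the $O(n \log k)$ bound and is $O(nk)$ in the worst case, but it needs no third-party package.

```python
from bisect import insort


def max_profit_bisect(k, prices):
    """Same loop as maxProfit above, with a sorted plain list instead of SortedList."""
    min_held, min_held_cash = 0, 0
    diffs = []  # kept sorted ascending
    for price in prices:
        insort(diffs, price)
        insort(diffs, price)
        if min_held > -k:
            min_held -= 1  # the new lowest holding is reachable only by selling short
            min_held_cash += price
        else:
            diffs.pop()  # drop the highest difference
            lo = diffs.pop(0)  # drop the lowest difference
            min_held_cash += max(price - lo, 0)  # sell short only if it beats holding
    for _ in range(-min_held):  # walk back up to zero shares held
        min_held_cash -= diffs.pop(0)
    return min_held_cash


print(max_profit_bisect(3, [1, 5, 2, 3, 9, 2, 4]))  # 13
```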