algorithm haskell functional-programming

"Time Limit Exceeded" for a Haskell solution to Knapsack


This is a very standard Knapsack problem from Kattis. Below is a straightforward dynamic programming solution in Haskell:

{-# Language OverloadedStrings #-}

import Control.Arrow ((>>>))
import Data.List (intercalate)
import Data.Array
import Data.Maybe
import qualified Data.ByteString.Lazy.Char8 as C

main = C.interact solve

solve = C.words >>> fmap readInt >>> divideInput
        >>> fmap (solveCase >>> toBS)
        >>> C.unlines
  where readInt = C.readInt >>> fromJust >>> fst

divideInput :: [Int] -> [[Int]]
divideInput [] = []
divideInput (c:n:ls) = (c : n : this) : divideInput that
  where (this, that) = splitAt (2*n) ls

solveCase :: [Int] -> [[Int]]
solveCase (c:n:os) = [[length is], is]
  where is = recover (n, c) []

        recover (i, j) rs | table ! (i, j) == 0 = rs
                          | table ! (i, j) == table ! (i-1, j) =
                            recover (i-1, j) rs
                          | table ! (i, j) == vi + (table ! (i-1, j-wi)) =
                            recover (i-1, j-wi) ((i-1):rs)
          where (vi, wi) = objs ! i

        objs :: Array Int (Int, Int)
        objs = listArray (1, n) $ pairs os
        pairs [] = []
        pairs (v:w:os) = (v,w) : pairs os

        -- table[i][j] is the max value that can be achieved with
        -- objects [1..i] where the total weight of selected
        -- objects is <= j.
        table :: Array (Int, Int) Int
        table = array bnds [(ij, fill ij) | ij <- range bnds]
          where 
            bnds = ((0,0), (n,c))
            fill (i, w) | i == 0 || w == 0 = 0
                        | w < wi = vx
                        | otherwise = max vx (vy+vi)
              where vx = table ! (i-1, w)
                    vy = table ! (i-1, w - wi)
                    (vi, wi) = objs ! i

toBS :: [[Int]] -> C.ByteString
toBS [[n], is] = C.intercalate "\n"
                 [C.pack (show n), C.intercalate " " $ C.pack . show <$> is]

However, the code gets a "Time Limit Exceeded" (TLE) verdict when submitted to Kattis, which is surprising given its O(Cn) complexity (picking from n objects with maximum capacity C). Does anyone have suggestions on how to fix this?

I have already tried using mutable arrays in the ST monad, but they did not help. That is not surprising, because the entries of the DP table never need to be updated once they have been computed.

I profiled it with C = 2000 and n = 2000, with values and weights picked uniformly at random between 1 and 20000. It took 1.166 seconds in total. The full runtime statistics are below:

     523,035,224 bytes allocated in the heap
     598,289,064 bytes copied during GC
     144,045,528 bytes maximum residency (4 sample(s))
         662,056 bytes maximum slop
             254 MiB total memory in use (0 MB lost due to fragmentation)

                                     Tot time (elapsed)  Avg pause  Max pause
  Gen  0       464 colls,     0 par    0.374s   0.394s     0.0008s    0.0196s
  Gen  1         4 colls,     0 par    0.141s   0.202s     0.0505s    0.1067s

  INIT    time    0.000s  (  0.004s elapsed)
  MUT     time    0.651s  (  0.664s elapsed)
  GC      time    0.515s  (  0.596s elapsed)
  EXIT    time    0.000s  (  0.001s elapsed)
  Total   time    1.166s  (  1.265s elapsed)

  %GC     time       0.0%  (0.0% elapsed)

  Alloc rate    804,028,826 bytes per MUT second

  Productivity  55.8% of total user, 52.5% of total elapsed
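
As a rough reading of the statistics above (a back-of-the-envelope sketch, not a measurement): with bounds ((0,0), (n,c)) the table has (n+1)(C+1) cells, and GC accounts for 0.515 s of the 1.166 s total. The helper names tableCells and bytesPerCell below are hypothetical, introduced only for this calculation. The allocation figure covers the whole program, including input parsing, so the per-cell number is only an upper bound on the table's share.

```haskell
import Data.Ix (rangeSize)

-- Number of cells in a DP table with bounds ((0,0), (n,c)).
tableCells :: Int -> Int -> Int
tableCells n c = rangeSize ((0, 0), (n, c))

-- Total heap allocation divided by the number of table cells.
-- The first argument is the "bytes allocated in the heap" figure.
bytesPerCell :: Int -> Int -> Int -> Double
bytesPerCell allocated n c =
  fromIntegral allocated / fromIntegral (tableCells n c)
```

For n = C = 2000 this gives 2001 * 2001 = 4,004,001 cells, and 523,035,224 allocated bytes works out to about 130 bytes per cell, far more than the 8 bytes an unboxed Int needs on a 64-bit machine. Data.Array is boxed and lazy in its elements, so each cell is built as a heap object rather than a raw machine word.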

Solution

  • I managed to get a variant accepted using IOUArray and IOArray. I also had to tweak the code to reduce the use of lists as much as possible. It was accepted with a running time of 1.61 seconds.

    I tried STUArray/STArray early on and thought they would offer the same performance as IOUArray/IOArray. However, although the STUArray/STArray-based solution used less memory and time than the one with immutable arrays, it still failed the last test file with TLE.

    I saw that the fastest accepted Haskell solution took only 0.38 seconds, and I am curious what it does to be that fast. My accepted code is attached below along with some profile info. Any ideas for further improving its performance are welcome.

    {-# Language OverloadedStrings #-}
    
    import Control.Arrow ((>>>))
    import Control.Monad
    import Data.List (intercalate)
    import Data.Array.IO
    import Data.Maybe
    import qualified Data.ByteString.Lazy.Char8 as C
      
    main = do
      inStr <- C.getContents
      solveCases $ (C.words >>> fmap readInt) inStr
      where readInt = C.readInt >>> fromJust >>> fst
    
    solveCases :: [Int] -> IO ()
    solveCases [] = return ()
    solveCases (c:n:os) = do
      objs <- newArray (1, n) (0,0)
      os' <- fillObj objs n os
      table <- buildTable objs
      indices <- recover table objs n c []
      putStrLn $ show $ length indices
      putStrLn $ intercalate " " $ show <$> indices
      solveCases os'
      where 
        recover :: IOUArray (Int, Int) Int ->
                   IOArray Int (Int, Int) ->
                   Int -> Int -> [Int] -> IO [Int]
        recover table objs i j rs = do
          v <- readArray table (i, j)
          if (v == 0)
            then return rs
            else do
            v' <- readArray table (i-1, j)
            if (v == v')
              then recover table objs (i-1) j rs
              else do
              (_, wi) <- readArray objs i
              recover table objs (i-1) (j-wi) ((i-1):rs)
              
        fillObj :: IOArray Int (Int, Int) -> Int -> [Int] -> IO [Int]
        fillObj objs n vws = go 1 vws
          where go :: Int -> [Int] -> IO [Int]
                go i (v:w:vws) | i == n = writeArray objs i (v, w) >> return vws
                               | otherwise = do
                                   writeArray objs i (v, w)
                                   go (i+1) vws
    
        bnds = ((0,0), (n,c))
    
        -- table[i][j] is the max value that can be achieved with
        -- objects [1..i] such that the total selected weight is <= j.
        buildTable :: IOArray Int (Int, Int) -> IO (IOUArray (Int, Int) Int)
        buildTable objs = do
          table <- newArray bnds 0
          forM_ (range bnds) $ \(i, w) -> do
            when (i > 0 && w > 0) $ do
              (vi, wi) <- readArray objs i
              vx <- readArray table (i-1, w)
              if (w < wi)
                then do
                writeArray table (i, w) vx
                else do
                vy <- readArray table (i-1, w-wi)
                writeArray table (i, w) $ max vx (vy+vi)
          return table
    
          33,901,296 bytes allocated in the heap
             183,504 bytes copied during GC
              53,320 bytes maximum residency (1 sample(s))
              36,792 bytes maximum slop
                  33 MiB total memory in use (0 MB lost due to fragmentation)
    
                                         Tot time (elapsed)  Avg pause  Max pause
      Gen  0         2 colls,     0 par    0.000s   0.000s     0.0001s    0.0002s
      Gen  1         1 colls,     0 par    0.000s   0.003s     0.0025s    0.0025s
    
      INIT    time    0.000s  (  0.004s elapsed)
      MUT     time    0.040s  (  0.048s elapsed)
      GC      time    0.000s  (  0.003s elapsed)
      EXIT    time    0.000s  (  0.006s elapsed)
      Total   time    0.040s  (  0.061s elapsed)
    
      %GC     time       0.0%  (0.0% elapsed)
    
      Alloc rate    847,130,013 bytes per MUT second
    
      Productivity  98.9% of total user, 79.3% of total elapsed
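
One direction that might be worth trying, offered as an untested sketch rather than a known fix: the accepted code still keeps the objects in a boxed IOArray of tuples and reads them through lists. The version below stores values and weights in two unboxed UArrays and builds the table with runSTUArray, which freezes the mutable array without copying. The names buildTable and chosen, and the choice to pass the objects as a list of (value, weight) pairs, are assumptions made for this example. Whether it fits the Kattis time limit has not been checked.

```haskell
import Control.Monad (forM_)
import Data.Array.ST (newArray, readArray, writeArray, runSTUArray)
import Data.Array.Unboxed (UArray, listArray, (!))

-- | table ! (i, j) is the best value using objects 1..i with total
-- weight <= j. Objects are given as (value, weight) pairs.
buildTable :: Int -> [(Int, Int)] -> UArray (Int, Int) Int
buildTable c items = runSTUArray $ do
  table <- newArray ((0, 0), (n, c)) 0
  forM_ [1 .. n] $ \i -> do
    let vi = values ! i
        wi = weights ! i
    forM_ [1 .. c] $ \j -> do
      skip <- readArray table (i - 1, j)
      if j < wi
        then writeArray table (i, j) skip
        else do
          takeIt <- readArray table (i - 1, j - wi)
          writeArray table (i, j) (max skip (takeIt + vi))
  return table
  where
    n = length items
    values, weights :: UArray Int Int
    values = listArray (1, n) (map fst items)
    weights = listArray (1, n) (map snd items)

-- | Walk back from (n, c) and collect the 0-based indices of the
-- objects that were taken, in increasing order.
chosen :: Int -> [(Int, Int)] -> [Int]
chosen c items = go n c []
  where
    n = length items
    table = buildTable c items
    weights :: UArray Int Int
    weights = listArray (1, n) (map snd items)
    go i j acc
      | i == 0 = acc
      | table ! (i, j) == table ! (i - 1, j) = go (i - 1) j acc
      | otherwise = go (i - 1) (j - weights ! i) (i - 1 : acc)
```

For example, chosen 6 [(5,4),(4,3),(3,2),(2,1)] evaluates to [1,2,3]. The recovery here stops at row 0 rather than at the first zero entry, which is a small deviation from the code above; both yield an optimal selection.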