I extracted the following minimal example from my production project. My machine learning project is made up of a linear algebra library, a deep learning library, and an application.
The linear algebra library contains a module for matrices based on storable vectors:
```haskell
module Matrix where

import Data.Vector.Storable hiding (sum)

data Matrix a = Matrix { rows :: Int, cols :: Int, items :: Vector a } deriving (Eq, Show, Read)

item :: Storable a => Int -> Int -> Matrix a -> a
item i j m = unsafeIndex (items m) $ i * cols m + j

multiply :: Storable a => Num a => Matrix a -> Matrix a -> Matrix a
multiply a b = Matrix (rows a) (cols b) $ generate (rows a * cols b) (f . flip divMod (cols b)) where
    f (i, j) = sum $ (\ k -> item i k a * item k j b) <$> [0 .. cols a - 1]
```
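To make the row-major indexing in `multiply` concrete, here is a small standalone check. It copies the definitions above into one file (importing `Data.Vector.Storable` qualified instead of unqualified) and multiplies two 2×2 matrices. `flip divMod (cols b)` turns a flat index `n` of the result into the pair `(row, column)`, so for a result with 2 columns, flat index 3 is row 1, column 1.

```haskell
import qualified Data.Vector.Storable as VS
import Foreign.Storable (Storable)

-- Same row-major representation as the Matrix module above.
data Matrix a = Matrix { rows :: Int, cols :: Int, items :: VS.Vector a }

item :: Storable a => Int -> Int -> Matrix a -> a
item i j m = VS.unsafeIndex (items m) (i * cols m + j)

multiply :: (Storable a, Num a) => Matrix a -> Matrix a -> Matrix a
multiply a b =
  Matrix (rows a) (cols b) $ VS.generate (rows a * cols b) (f . flip divMod (cols b))
  where
    -- f receives (n `div` cols b, n `mod` cols b) for flat result index n
    f (i, j) = sum $ (\k -> item i k a * item k j b) <$> [0 .. cols a - 1]

main :: IO ()
main = do
  let a = Matrix 2 2 (VS.fromList [1, 2, 3, 4]) :: Matrix Double
      b = Matrix 2 2 (VS.fromList [5, 6, 7, 8])
  print (divMod 3 (cols b))                -- prints (1,1)
  print (VS.toList (items (multiply a b))) -- prints [19.0,22.0,43.0,50.0]
```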
The deep learning library uses the linear algebra library to implement the forward pass through a deep neural network:
```haskell
module Deep where

import Foreign.Storable
import Matrix

transform :: Storable a => Num a => [Matrix a] -> Matrix a -> Matrix a
transform layers batch = foldr multiply batch layers
```
And finally the application uses the deep learning library:
```haskell
import qualified Data.Vector.Storable as VS
import Test.Tasty.Bench
import Matrix
import Deep

main :: IO ()
main = defaultMain [bmultiply] where
    bmultiply = bench "bmultiply" $ nf (items . transform layers) batch where
        m k l c = Matrix k l $ VS.replicate (k * l) c :: Matrix Double
        layers = m 256 256 <$> [0.1, 0.2, 0.3]
        batch = m 256 100 0.4
```
I like the fact that the deep learning library, and (with some exceptions related to BLAS via FFI) also the linear algebra library, do not have to worry about concrete types like `Float` or `Double`. Unfortunately, this also means that unless specialization takes place, they use boxed values, and performance is about 60x worse than it could be (959 ms instead of 16.7 ms).
The only way I have found to get good performance is to force either inlining or specialization throughout the entire call hierarchy via compiler pragmas. This is very annoying because the performance issue, which fundamentally should be specific to the `multiply` function, now "infects" the entire code base. Even very high-level functions that use `multiply` through five levels of indirection and several intermediate libraries somehow have to "know" about technical specialization issues deep down.
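For concreteness, the pragma-based workaround looks roughly like the following sketch. It is collapsed into a single file so that it runs standalone, and it uses `INLINABLE` on every level. In the real layout each function lives in a different module or package, which is where the pragmas matter: without one, GHC may not keep a function's unfolding in the interface file, so callers in other modules cannot specialise it (or anything it calls).

```haskell
import qualified Data.Vector.Storable as VS
import Foreign.Storable (Storable)

data Matrix a = Matrix { rows :: Int, cols :: Int, items :: VS.Vector a }

-- Level 1 (Matrix module): the only place with a hot inner loop.
item :: Storable a => Int -> Int -> Matrix a -> a
item i j m = VS.unsafeIndex (items m) (i * cols m + j)
{-# INLINABLE item #-}

multiply :: (Storable a, Num a) => Matrix a -> Matrix a -> Matrix a
multiply a b =
  Matrix (rows a) (cols b) $ VS.generate (rows a * cols b) (f . flip divMod (cols b))
  where
    f (i, j) = sum $ (\k -> item i k a * item k j b) <$> [0 .. cols a - 1]
{-# INLINABLE multiply #-}

-- Level 2 (Deep module): no loop of its own, yet it still needs a pragma so
-- that a caller in another module can specialise it, and through it multiply.
transform :: (Storable a, Num a) => [Matrix a] -> Matrix a -> Matrix a
transform layers batch = foldr multiply batch layers
{-# INLINABLE transform #-}

-- Level 3 (application): the first place the concrete type Double appears.
main :: IO ()
main = do
  let m k l c = Matrix k l (VS.replicate (k * l) c) :: Matrix Double
      out = transform (m 2 2 <$> [1, 2, 3]) (m 2 1 1)
  print (VS.toList (items out)) -- prints [48.0,48.0]
```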
In my actual production code, many more functions are affected than in this minimal example. Forgetting to annotate just one of these functions with the right compiler pragma immediately destroys performance. Additionally, when developing a library, I have no way of knowing which types it will be used with, so `SPECIALIZE` pragmas are not an option anyway.
This is particularly unfortunate because all the performance-critical tight loops are wholly contained within the `multiply` function. The function itself is only called a handful of times, so it would not hurt performance if values were unboxed dynamically, only when `multiply` is called. In the end, there is really no need for values to be specialized and unboxed inside the high-level machine learning functions. I feel like there should be a way to pass the request for specialization through to the low-level functions while keeping high- and intermediate-level functions polymorphic.
How is this problem typically solved in Haskell? If I develop a library that uses the `vector` package to generate blazingly fast code in tight loops, how do I pass that performance on to users of my library without losing all polymorphism or forcing everything to be inlined?
Is there a way to pay the price for polymorphism (in the form of boxing) only within the high-level functions and specialize and unbox only at the boundary to the functions that need it, rather than having specialization "infect" the entire call hierarchy?
If you browse the source for, say, the `vector` package, you'll find that nearly every function has an `INLINABLE` or `INLINE` pragma, whether the function is part of the low-level, performance-critical core or part of a high-level generic interface. You'll see something similar if you look at `lens`, `hmatrix`, etc.
So, the short answer is: no, the only way to get good performance with your current design will be to infect the entire call hierarchy with pragmas. The best way to avoid missing a pragma and tanking performance will be to have an exhaustive set of benchmarks that can detect performance regressions.
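If changing the design is an option, one way to confine specialization to the kernel is to make the kernel a type-class method whose instances are monomorphic. The following is a sketch under that assumption, not benchmarked here; `Kernel` and `multiplyGeneric` are hypothetical names. High-level code such as `transform` then carries only a `Kernel a` constraint and no pragmas: each call to `multiply` costs one dictionary lookup, after which the instance runs a version of the loop that GHC (with `-O`) should have specialised to `Double` or `Float` via the `SPECIALIZE` pragmas. The trade-off is that the library must provide one instance per element type it wants to be fast for.

```haskell
import qualified Data.Vector.Storable as VS
import Foreign.Storable (Storable)

data Matrix a = Matrix { rows :: Int, cols :: Int, items :: VS.Vector a }

item :: Storable a => Int -> Int -> Matrix a -> a
item i j m = VS.unsafeIndex (items m) (i * cols m + j)

-- The generic loop, as in the question, plus specialisations for the
-- element types the (hypothetical) Kernel class supports.
multiplyGeneric :: (Storable a, Num a) => Matrix a -> Matrix a -> Matrix a
multiplyGeneric a b =
  Matrix (rows a) (cols b) $ VS.generate (rows a * cols b) (f . flip divMod (cols b))
  where
    f (i, j) = sum $ (\k -> item i k a * item k j b) <$> [0 .. cols a - 1]
{-# SPECIALIZE multiplyGeneric :: Matrix Double -> Matrix Double -> Matrix Double #-}
{-# SPECIALIZE multiplyGeneric :: Matrix Float -> Matrix Float -> Matrix Float #-}

-- Hypothetical class: the only place that knows about concrete types.
class (Storable a, Num a) => Kernel a where
  multiply :: Matrix a -> Matrix a -> Matrix a

-- Each instance is monomorphic, so it can use the specialised loop.
instance Kernel Double where
  multiply = multiplyGeneric

instance Kernel Float where
  multiply = multiplyGeneric

-- High-level code stays polymorphic and needs no pragmas; the dictionary is
-- consulted once per call to multiply, not inside the inner loop.
transform :: Kernel a => [Matrix a] -> Matrix a -> Matrix a
transform layers batch = foldr multiply batch layers

main :: IO ()
main = do
  let m k l c = Matrix k l (VS.replicate (k * l) c) :: Matrix Double
      out = transform (m 2 2 <$> [0.5, 2]) (m 2 1 1)
  print (VS.toList (items out)) -- prints [4.0,4.0]
```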
There are a few compiler flags that might be helpful. The flag `-fexpose-all-unfoldings` makes sure that inlinable versions of all functions find their way into the interface files, while the flag `-fspecialise-aggressively` looks for any opportunity to specialize those functions. Together, they are kind of like turning on `INLINE` for every function. This probably isn't a good permanent solution, but it might be useful during development or as a sanity check to get some baseline performance numbers.
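As a minimal sketch, the flags can be set per module with an `OPTIONS_GHC` pragma (compile with `-O` or `-O2`, since specialisation only happens when optimising). `-fexpose-all-unfoldings` has to be active in the modules that define the functions and `-fspecialise-aggressively` in the modules that call them, so in a multi-package project it is probably simpler to add both to the `ghc-options` field of every package's `.cabal` file. The `dot` function is a hypothetical stand-in for any overloaded library function without pragmas.

```haskell
{-# OPTIONS_GHC -fexpose-all-unfoldings -fspecialise-aggressively #-}
module Main (main) where

import qualified Data.Vector.Storable as VS
import Foreign.Storable (Storable)

-- Overloaded, with no INLINABLE or SPECIALIZE pragma. With the flags above,
-- its unfolding is written to the interface file, so a module that imports it
-- and is also compiled with -fspecialise-aggressively can specialise it.
dot :: (Storable a, Num a) => VS.Vector a -> VS.Vector a -> a
dot xs ys = VS.sum (VS.zipWith (*) xs ys)

main :: IO ()
main = print (dot (VS.fromList [1, 2, 3]) (VS.fromList [4, 5, 6 :: Double])) -- prints 32.0
```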