
How can I generalize my sampling framework?


In the context of a stochastic ray tracer, I'd like to decouple the MC integration (path tracing, bidirectional path tracing) from sample generation (uniform random, stratified, poisson, metropolis, ...). Most of this is already implemented, but it's tedious to use. So I ditched that and am trying to build something nicer by splitting sampled computations into two phases: in SampleGen you are allowed to request a random value using the mk1d and mk2d functions, which are then supplied with actual Floats by the sampling algorithm. Those values can be examined in SampleRun to do the actual computation. Here's some code with the interesting bits of a stratified sampler and its use:

{-# LANGUAGE GeneralizedNewtypeDeriving #-}

import Control.Applicative
import Control.Monad.State.Strict
import Control.Monad.Primitive
import System.Random.MWC as MWC

-- used to construct sampled computations
newtype SampleGen s m a = SampleGen (StateT s m a)
                       deriving ( Functor, Applicative, Monad
                                , MonadState s, MonadTrans )

-- used to evaluate sampled computations constructed in SampleGen
newtype SampleRun s m a = SampleRun (StateT s m a)
                       deriving ( Functor, Applicative, Monad
                                , MonadState s )

-- a sampled computation, parametrized over the generator's state g,
-- the evaluator's state r,  the underlying monad m and the result
-- type a
type Sampled g r m a = SampleGen g m (SampleRun r m a)

----------------------
-- Stratified Sampling
----------------------

-- | we just count the number of requested 1D samples
type StratGen = Int

-- | the pre-computed values and a RNG for additional ones
type StratRun m = ([Float], Gen (PrimState m))

-- | specialization of Sampled for stratified sampling
type Stratified m a = Sampled StratGen (StratRun m) m a

-- | gives a sampled value in [0..1), this is kind
--   of the "prime" value, upon which all computations
--   are built
mk1d :: PrimMonad m => Stratified m Float
mk1d = do
  n1d <- get
  put $ n1d + 1

  return $ SampleRun $ do
    fs <- gets fst
    if length fs > n1d
      then return (fs !! n1d)
      else gets snd >>= lift . MWC.uniform

-- | gives a pair of stratified values, should really also
--   be a "prime" value, but here we just construct them
--   from two 1D samples for fun
mk2d :: (Functor m, PrimMonad m) => Stratified m (Float, Float)
mk2d = mk1d >>= \f1 -> mk1d >>= \f2 ->
  return $ (,) <$> f1 <*> f2

-- | evaluates a stratified computation
runStratified
  :: (PrimMonad m)
  => Int            -- ^ number of samples
  -> Stratified m a -- ^ computation to evaluate
  -> m [a]          -- ^ the values produced, a list of nsamples values
runStratified nsamples (SampleGen c) = do
  (SampleRun x, n1d) <- runStateT c 0
  -- let's just pretend I'd use n1d to actually
  -- compute stratified samples
  gen <- MWC.create
  replicateM nsamples $ evalStateT x ([{- samples would go here -}], gen)

-- estimate Pi by Monte Carlo sampling
-- mcPi :: (Functor m, PrimMonad m) => Sampled g r m Float
mcPi :: (Functor m, PrimMonad m) => Stratified m Float
mcPi = do
  v <- mk2d
  return $ v >>= \(x, y) -> return $ if x * x + y * y < 1 then 4 else 0

main :: IO ()
main = do
  vs <- runStratified 10000 mcPi :: IO [Float]
  print $ sum vs / fromIntegral (length vs)

The missing part here is that in its current form, the mcPi function has the type

mcPi :: (Functor m, PrimMonad m) => Stratified m Float

while it should really be something like

mcPi :: (Functor m, PrimMonad m) => Sampled g r m Float

Admittedly, the four type parameters on Sampled aren't exactly beautiful, but at least something like this would be useful. In summary, I'm looking for something that allows expressing computations like mcPi independently of the sampling algorithm, e.g.:

  • a uniform random sampler does not need to maintain any state in the SampleGen phase, and needs only a RNG in the SampleRun phase
  • both the stratified and the Poisson disk sampler (and probably others) keep track of the number of 1D and 2D samples needed and precompute them into a vector, and they would be allowed to share a SampleGen and SampleRun implementation, differing only in what happens in between SampleGen and SampleRun (how the vector is actually filled)
  • a metropolis sampler would use a lazy sample generation technique in its SampleRun phase

I'd like to compile it using GHC, so extensions like MultiParamTypeClasses and TypeFamilies are fine with me, but I did not come up with anything remotely usable.

PS: As motivation, some pretty pictures. And the code in its current form is on GitHub.


Solution

  • I'm going to start off with a radically different question, "What should the code look like?", and then work towards the question "How is the sampling framework put together?"

    What the code should look like

    The definition of mcPi should be

    mcPi :: (Num s, Ord s, Num p) => s -> s -> p
    mcPi x y = if x * x + y * y < 1 then 4 else 0
    

    The Monte Carlo estimation of pi is that, given two numbers (that happen to come from the interval [0..1)), the estimate is 4, the area of the square [-1..1]², if the point they form falls within the unit circle, and 0 otherwise. A uniformly distributed point lands inside the quarter circle with probability pi/4, so the average of these estimates converges to pi. The Monte Carlo estimation of pi doesn't know anything about computation. It doesn't know whether it's going to be repeated, or anything about where the numbers came from. It does know that the numbers should be uniformly distributed over the square, but that's a topic for a different question. The Monte Carlo estimation of pi is just a function from the samples to the estimate.
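
    Because mcPi is an ordinary function, any source of points can drive it. As a minimal sketch (the helpers gridPoints and estimateWith are hypothetical, written only for illustration), here it is averaged over the midpoints of a regular n-by-n grid, a deterministic stand-in for stratified sampling with one sample per stratum:

```haskell
-- The estimator from above; Ord is needed for the comparison.
mcPi :: (Num s, Ord s, Num p) => s -> s -> p
mcPi x y = if x * x + y * y < 1 then 4 else 0

-- Hypothetical helper: midpoints of an n-by-n grid over [0,1)^2,
-- i.e. one sample at the centre of each stratum.
gridPoints :: Int -> [(Double, Double)]
gridPoints n =
  [ (centre i, centre j) | i <- [0 .. n - 1], j <- [0 .. n - 1] ]
  where
    centre k = (fromIntegral k + 0.5) / fromIntegral n

-- Hypothetical helper: average the estimator over any list of points.
-- mcPi itself never learns where the points came from.
estimateWith :: [(Double, Double)] -> Double
estimateWith pts = sum (map (uncurry mcPi) pts) / fromIntegral (length pts)
```

    With n = 200, estimateWith (gridPoints 200) is within 0.05 of pi: only grid cells crossed by the arc can be misclassified, and there are fewer than 2n of them.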

    Other random things will know that they are part of a random process. A simple random process might be: flip a coin; if the coin comes up "heads", flip it again.

    simpleRandomProcess :: (Monad m, MonadCoinFlip m) => m Coin
    simpleRandomProcess =
        do
            firstFlip <- flipACoin
            case firstFlip of
                Heads -> flipACoin
                Tails -> return firstFlip
    

    This random process would like to be able to see things like

    data Coin = Heads | Tails
    
    class MonadCoinFlip m where
        flipACoin :: m Coin -- The coin should be fair
    

    Random processes may change how much random data they need based on the results of previous experiments. This suggests that we will ultimately need to provide a Monad.
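
    To make the point about varying amounts of randomness concrete, here is a minimal sketch that runs simpleRandomProcess against a scripted list of flips and counts how many were consumed. The Scripted monad and runScripted are hypothetical, built on State from the transformers package; Coin also gains Eq and Show instances so results can be compared.

```haskell
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
import Control.Monad.Trans.State (State, runState, state)

data Coin = Heads | Tails deriving (Eq, Show)

class MonadCoinFlip m where
  flipACoin :: m Coin -- The coin should be fair

-- Hypothetical instance: replay a fixed script of flips and count
-- how many were consumed.
newtype Scripted a = Scripted (State ([Coin], Int) a)
  deriving (Functor, Applicative, Monad)

instance MonadCoinFlip Scripted where
  flipACoin = Scripted $ state $ \(coins, used) -> case coins of
    (c : cs) -> (c, (cs, used + 1))
    []       -> error "script ran out of flips"

-- Returns the result together with the number of flips used.
runScripted :: [Coin] -> Scripted a -> (a, Int)
runScripted coins (Scripted m) =
  let (a, (_, used)) = runState m (coins, 0) in (a, used)

simpleRandomProcess :: (Monad m, MonadCoinFlip m) => m Coin
simpleRandomProcess = do
  firstFlip <- flipACoin
  case firstFlip of
    Heads -> flipACoin
    Tails -> return firstFlip
```

    Starting with Tails uses one flip; starting with Heads uses two. The number of samples is only known after the first result, which is why this process needs a Monad rather than just an Applicative.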

    The interface

    You would like to "decouple the MC integration (path tracing, bidirectional path tracing) from sample generation (uniform random, stratified, poisson, metropolis, ...)". In your examples, they all want to sample floats. That suggests the following class

    class MonadSample m where
        sample :: m Float -- Should be on the interval [0..1)
    

    This is very similar to the existing MonadRandom class, except for two things. A MonadRandom implementation essentially needs to provide a uniformly random Int in some range of its own choosing. Your sampler will provide a Float sample of unknown distribution on the interval [0..1). This is different enough to justify having your own new class.

    Due to the upcoming Applicative-Monad change, and because some users of this class only need an Applicative, I'm instead going to suggest a name that doesn't mention Monad: SampleSource.

    class SampleSource f where
        sample :: f Float -- Should be on the interval [0..1)
    

    sample replaces mk1d in your code. mk2d can also be replaced, again without knowing what the source of the samples will be. sample2d, the replacement for mk2d, works with any Applicative sample source; it doesn't need a Monad, because it won't decide how many samples to get, or what else to do, based on the result of samples. The structure of its computation is known ahead of time.

    sample2d :: (Applicative f, SampleSource f) => f (Float, Float)
    sample2d = (,) <$> sample <*> sample
    
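    As a minimal sketch of one concrete sample source, here is a hypothetical Replay type that hands out a fixed list of Floats in order (handy for testing, not a real sampler). sample2d and a single Monte Carlo draw of pi run unchanged against it. Replay, runReplay and piSample are illustrative names, and Replay is built on State from transformers.

```haskell
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
import Control.Monad.Trans.State (State, evalState, state)

class SampleSource f where
  sample :: f Float -- Should be on the interval [0..1)

sample2d :: (Applicative f, SampleSource f) => f (Float, Float)
sample2d = (,) <$> sample <*> sample

mcPi :: (Num s, Ord s, Num p) => s -> s -> p
mcPi x y = if x * x + y * y < 1 then 4 else 0

-- Hypothetical source: hand out pre-recorded samples in order.
newtype Replay a = Replay (State [Float] a)
  deriving (Functor, Applicative, Monad)

instance SampleSource Replay where
  sample = Replay $ state $ \fs -> case fs of
    (f : rest) -> (f, rest)
    []         -> error "Replay: out of samples"

runReplay :: [Float] -> Replay a -> a
runReplay fs (Replay m) = evalState m fs

-- One Monte Carlo draw, written once against any Applicative source.
piSample :: (Applicative f, SampleSource f) => f Float
piSample = uncurry mcPi <$> sample2d
```

    Applicative effects run left to right, so runReplay [0.1, 0.2] sample2d gives (0.1, 0.2).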

    If you are going to allow the sample source to introduce interactions between dimensions, for example for Poisson disk sampling, you'd need to add that to the interface, either explicitly enumerating the dimensions

    class SampleSource f where
        sample   :: f Float
        sample2d :: f (Float, Float)
        sample3d :: f (Float, Float, Float)
        sample4d :: f (Float, Float, Float, Float)
    

    or using some vector library.

    class SampleSource f where
        sample  :: f Float
        samples :: Int -> f (Vector Float)
    

    Implementing the interface

    Now, we need to describe how each of your sample sources can be used as a SampleSource. As an example, I'll implement SampleSource for one of the worst sample sources there is.

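    {-# LANGUAGE GeneralizedNewtypeDeriving #-}
    import Control.Applicative (Alternative)
    import Control.Monad (MonadPlus)
    import Control.Monad.IO.Class (MonadIO)
    import Control.Monad.Trans.Class (MonadTrans)
    import Control.Monad.Trans.Identity (IdentityT (..))
    
    newtype ZeroSampleSourceT m a = ZeroSampleSourceT {
        unZeroSampleSourceT :: IdentityT m a
    } deriving (MonadTrans, Monad, Functor, MonadPlus, Applicative, Alternative, MonadIO)
    
    -- the instance is for the type constructor ZeroSampleSourceT m, not ZeroSampleSourceT m a
    instance (Monad m) => SampleSource (ZeroSampleSourceT m) where
        sample = return 0
    
    runZeroSampleSourceT :: (Monad m) => ZeroSampleSourceT m a -> m a
    runZeroSampleSourceT = runIdentityT . unZeroSampleSourceT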
    

    Once all Monads are Applicative (as they are since GHC 7.10), I'd instead write

    instance (Applicative f) => SampleSource (ZeroSampleSourceT f) where
        sample = pure 0
    

    I'll also implement an MWC uniform SampleSource.

    import Control.Monad.Primitive (PrimMonad, PrimState)
    import Control.Monad.Trans.Reader (ReaderT (..))
    import System.Random.MWC (Gen, uniform)
    
    -- MonadTrans is not derivable here: the environment type mentions m
    newtype MWCUniformSampleSourceT m a = MWCUniformSampleSourceT {
        unMWCUniformSampleSourceT :: ReaderT (Gen (PrimState m)) m a
    } deriving (Monad, Functor, MonadPlus, Applicative, Alternative, MonadIO)
    
    runMWCUniformSampleSourceT :: MWCUniformSampleSourceT m a -> Gen (PrimState m) -> m a
    runMWCUniformSampleSourceT = runReaderT . unMWCUniformSampleSourceT
    
    -- MWC's uniform generates floats on the half-open interval (0,1].
    -- Wrap 1 around to 0 to get [0,1). Subtracting 2**(-33) does not
    -- work for Float: 1 - 2**(-33) rounds back to 1.
    uniformClosedOpen :: PrimMonad m => Gen (PrimState m) -> m Float
    uniformClosedOpen = fmap (\x -> if x >= 1 then x - 1 else x) . uniform
    
    instance (PrimMonad m) => SampleSource (MWCUniformSampleSourceT m) where
        sample = MWCUniformSampleSourceT . ReaderT $ uniformClosedOpen
    

    We won't completely implement Stratified or runStratified, since your example code doesn't contain complete implementations for them.

    But I want to know how many samples will be used ahead of time

    I'm not sure exactly what you are trying to do with "stratified" sampling. Pre-generating numbers, and using a generator when those run out isn't what I understand stratified sampling to be. If you are going to provide a monadic interface to something, you won't be able to tell ahead of time what will be executed, so you won't be able to predict how many samples a computation will need before you start executing it. If you can settle for only an Applicative interface, then you can test ahead of time how many samples will be needed by the entire computation.
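
    A minimal sketch of that look-ahead: an Applicative that draws nothing and only counts requests. Count and samplesNeeded are hypothetical names (Const (Sum Int) from base behaves the same way). Count has no Monad instance, and could not have a sensible one, which mirrors the restriction just described: only computations whose structure is fixed in advance can be measured this way.

```haskell
class SampleSource f where
  sample :: f Float -- Should be on the interval [0..1)

sample2d :: (Applicative f, SampleSource f) => f (Float, Float)
sample2d = (,) <$> sample <*> sample

-- Hypothetical counter: the type parameter is phantom and no value is
-- ever produced; only the number of requested samples is recorded.
newtype Count a = Count Int

instance Functor Count where
  fmap _ (Count n) = Count n

instance Applicative Count where
  pure _ = Count 0
  Count m <*> Count n = Count (m + n)

instance SampleSource Count where
  sample = Count 1

samplesNeeded :: Count a -> Int
samplesNeeded (Count n) = n
```

    For example, samplesNeeded (sequenceA (replicate 6 sample2d)) is 12, a number a stratified sampler could use to set up its strata before any computation runs.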

    But Poisson Disk sampling needs to know how many points are being sampled ahead of time

    If a single sampling can depend on both the number of samples needed and the number of dimensions, like in Poisson Disk sampling, those need to be passed to the sampler when they become known.

    class SampleSource f where
        sample   :: f Float
        samples  :: Int -> f ([Float])
        sampleN  :: Int -> f (Vector Float)
        samplesN :: Int -> Int -> f ([Vector Float])
    
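    One way to keep such an extended interface cheap for simple sources is to give the grouped method a default built from sample, so that only samplers that correlate samples (Poisson disk, stratified) need to override it. A minimal sketch under those assumptions: plain lists stand in for Vector, only samplesN is shown, Applicative is made a superclass so the default can combine requests, and Replay is a hypothetical source that replays a fixed list of Floats.

```haskell
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
import Control.Monad.Trans.State (State, evalState, state)

-- Applicative is a superclass so the default can combine requests.
class Applicative f => SampleSource f where
  sample :: f Float
  -- | n points of dimension d. The default draws every coordinate
  -- independently; a Poisson-disk or stratified source would override
  -- it, because it sees n and d before producing anything.
  samplesN :: Int -> Int -> f [[Float]]
  samplesN n d = sequenceA (replicate n (sequenceA (replicate d sample)))

-- Hypothetical source that replays a fixed list of samples.
newtype Replay a = Replay (State [Float] a)
  deriving (Functor, Applicative, Monad)

instance SampleSource Replay where
  sample = Replay $ state $ \fs -> case fs of
    (f : rest) -> (f, rest)
    []         -> error "Replay: out of samples"
  -- samplesN comes from the class default

runReplay :: [Float] -> Replay a -> a
runReplay fs (Replay m) = evalState m fs
```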

    You could generalize this to sampling in arbitrary shapes in arbitrary dimensions, which is what we'd need to do if we took the next leap.

    Applicative query language with a Monadic interpreter

    We can go very, very elaborate and make an Applicative query language for requests for samples. The language will need to add two features on top of what Applicative already does: it will need to be able to repeat requests, and it will need to group requests for samples together to identify which groupings are meaningful. It's motivated by the following code, which wants to get 6 different 2d samples, where sample2d is the same as our first definition.

    take 6 (repeat sample2d)
    

    First, we'll need to be able to repeat things over and over. The nicest way to do this would be if we could write, e.g.

    take 6 (repeat sample) :: SampleSource f => [f Float]
    

    We'd need a way to go from an [f a] to f [a]. This already exists; it's Data.Traversable's sequenceA, which requires that f be Applicative. So we already get repetition from Applicative.

    sequenceA . take 6 . repeat $ sample2d
    

    To group requests together, we'll add a function to mark which groupings are meaningful.

    sequenceA . take 6 . repeat . mark $ sample2d
    

    and a class for things that can mark some grouping. If we need more meaning than just groupings, for example whether the samples inside should be dependent or independent, we'd add it here.

    class Mark f where
        mark :: f a -> f a
    
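    To illustrate what a sampler gains from mark, here is a minimal sketch of an Applicative that never draws anything and only records the shape of a query as a tree of requests. Shape, Req and describe are hypothetical names; a Poisson-disk or stratified sampler could inspect such a description before generating any points.

```haskell
class SampleSource f where
  sample :: f Float

class Mark f where
  mark :: f a -> f a

sample2d :: (Applicative f, SampleSource f) => f (Float, Float)
sample2d = (,) <$> sample <*> sample

-- Hypothetical request tree: a single sample, or a marked group.
data Req = One | Group [Req]
  deriving (Eq, Show)

-- Records the sequence of requests; never produces an 'a'.
newtype Shape a = Shape [Req]

instance Functor Shape where
  fmap _ (Shape rs) = Shape rs

instance Applicative Shape where
  pure _ = Shape []
  Shape xs <*> Shape ys = Shape (xs ++ ys)

instance SampleSource Shape where
  sample = Shape [One]

instance Mark Shape where
  mark (Shape rs) = Shape [Group rs]

describe :: Shape a -> [Req]
describe (Shape rs) = rs
```

    describe (mark (sequenceA (replicate 2 (mark sample2d)))) is [Group [Group [One, One], Group [One, One]]]: the inner marks identify the 2D points and the outer mark identifies the set of points.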

    If everything is going to be very homogeneous, we might add a class for query-able sample sources

    class (Applicative f, Mark f, SampleSource f) => QueryableSampleSource f where
    

    Now we will talk about the idea of a monad that has a more-optimized query language. Here we will start using all of those GHC-specific extensions; specifically TypeFamilies.

    class MonadQuery m where
        type Query m :: * -> *
        interpret :: (Query m a) -> m a
    

    And finally a class for monad sample sources with an Applicative query language

    class (MonadQuery m, QueryableSampleSource (Query m), SampleSource m, Monad m) => MonadSample m where
    

    At this point, we will want to work out what laws these should follow. I'd suggest a few:

    interpret sample == sample
    interpret (sequenceA qs) == traverse interpret qs
    

    That is, without a mark, sample sources don't get to do anything terribly special with the queries. This would mean that a query that wants to be subject to Poisson disk's special treatment of 2d points and special treatment of the set of points would need to be marked twice:

     mark . sequenceA . take 6 . repeat . mark $ sample2d
    

    The Applicative query language roughly corresponds to your StratGen type: by having a merely Applicative interface, it allows you to look ahead at the structure of the incoming query. The Monad then corresponds to your StratRun type.
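
    To make that correspondence concrete, here is a minimal sketch of an Applicative query type with two parts: a count of the samples it needs (the StratGen role) and a function that consumes exactly that many pre-generated samples (the StratRun role). Batch, Replay, runReplay and batchSize are hypothetical names, and Data.Kind.Type is used in place of *. Replay merely replays a fixed list, whereas a real stratified sampler would use the count to generate stratified values for the whole batch before running the consumer.

```haskell
{-# LANGUAGE GeneralizedNewtypeDeriving, TypeFamilies #-}
import Control.Monad.Trans.State (State, evalState, state)
import Data.Kind (Type)

class SampleSource f where
  sample :: f Float

class MonadQuery m where
  type Query m :: Type -> Type
  interpret :: Query m a -> m a

sample2d :: (Applicative f, SampleSource f) => f (Float, Float)
sample2d = (,) <$> sample <*> sample

-- Query: how many samples are needed (look-ahead, the StratGen role)
-- and how to consume exactly that many, left to right (StratRun role).
data Batch a = Batch Int ([Float] -> a)

instance Functor Batch where
  fmap f (Batch n g) = Batch n (f . g)

instance Applicative Batch where
  pure x = Batch 0 (const x)
  Batch m f <*> Batch n g =
    Batch (m + n) (\fs -> let (here, rest) = splitAt m fs in f here (g rest))

instance SampleSource Batch where
  sample = Batch 1 firstSample
    where
      firstSample (x : _) = x
      firstSample []      = error "Batch: empty block"

batchSize :: Batch a -> Int
batchSize (Batch n _) = n

-- Hypothetical monad: replays a fixed stream of samples.
newtype Replay a = Replay (State [Float] a)
  deriving (Functor, Applicative, Monad)

instance SampleSource Replay where
  sample = Replay (state next)
    where
      next (f : rest) = (f, rest)
      next []         = error "Replay: out of samples"

-- The batch size is known before anything is consumed, so the whole
-- block is taken from the stream in one step.
instance MonadQuery Replay where
  type Query Replay = Batch
  interpret (Batch n f) = Replay $ state $ \fs ->
    let (block, rest) = splitAt n fs
    in if length block < n
         then error "Replay: out of samples"
         else (f block, rest)

runReplay :: [Float] -> Replay a -> a
runReplay fs (Replay m) = evalState m fs
```

    interpret sees the batch size before any sample is consumed, and for a single sample it agrees with the monad's own sample, in line with the first law suggested above.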