I'd like to define `cpsRec`
as follows, but I couldn't get it to compile.
Please let me know if you have any ideas for how to implement it.
import Control.Monad.Trans.Cont (Cont)
-- | Type of the continuation-passing form of @x@ with answer type @r@: every
-- argument of @x@ is kept, and the final (non-function) result becomes
-- @Cont r final@. Closed-family equations are tried top to bottom, so the
-- arrow equation applies whenever @x@ is known to be a function type.
type family ContRec r x where
  ContRec r (arg -> rest) = arg -> ContRec r rest
  ContRec r final         = Cont r final
-- | Attempted conversion of a function into continuation-passing style,
-- driven by the 'ContRec' type family. As written this does not compile:
--
--  * @r@ occurs only as an argument of 'ContRec', which is not injective,
--    so GHC cannot recover @r@ from @ContRec r b@ and rejects the signature
--    as ambiguous (unless AllowAmbiguousTypes is enabled).
--  * Types are erased at runtime and @case@ can only match data
--    constructors, so @(x -> y)@ is not a valid pattern. Whether @fa@ is a
--    function has to be decided at compile time instead, e.g. by a type
--    class that dispatches on the type.
--  * @fa@ has type @b@, so @pure fa@ has type @m b@, but the expected result
--    @ContRec r b@ cannot be reduced while @b@ is an unknown type variable
--    (it might still be a function type), so the two types do not unify.
cpsRec :: (a -> b) -> (a -> ContRec r b)
cpsRec f a =
let fa = f a
in case fa of
-- `(x -> y)` is a type, not a constructor pattern
(x -> y) -> cpsRec fa -- error!
-- `ContRec r b` is stuck for an unknown `b`, so `pure fa` cannot match it
_ -> pure fa -- error!
-- use case
-- | Direct-style sum of three integers; the function to be CPS-converted.
addT :: Int -> Int -> Int -> Int
addT x y z = sum [x, y, z]
-- | 'addT' in continuation-passing style: the same three arguments, with the
-- sum delivered inside @Cont r@.
addCpsT :: Int -> Int -> Int -> Cont r Int
addCpsT x y z = cpsRec addT x y z
Here is an example implementation of `cpsRec`
that works for functions with any number of arguments:
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
import Control.Monad.Trans.Cont (Cont)
import Data.Proxy (Proxy(..))
-- | Type-level test: is @t@ a function type? Closed-family equations are
-- tried in order, so a type known not to be an arrow falls through to
-- @'False@. For a type variable that is not yet known, the family does not
-- reduce.
type family IsFun t :: Bool where
  IsFun (arg -> res) = 'True
  IsFun other        = 'False
-- | Worker class behind 'cpsRec'. The extra promoted 'Bool' index @flag@
-- (passed as @IsFun fn@ by its users) picks between the "still a function"
-- instance and the "final result" instance, whose heads would otherwise
-- overlap.
class GContRec (flag :: Bool) fn res where
  gcpsRec :: Proxy flag -> fn -> res
-- | Recursive case, selected when the flag is @'True@ (the value has type
-- @arg -> res@). The converted function takes the same first argument,
-- applies the original function to it, and converts what remains (@res@),
-- letting @IsFun res@ choose the next instance. The equality constraint
-- ties the overall result type @out@ to @arg -> rest@.
instance (GContRec (IsFun res) res rest, (arg -> rest) ~ out)
  => GContRec 'True (arg -> res) out where
  gcpsRec _ fn x = gcpsRec (Proxy :: Proxy (IsFun res)) (fn x)
-- | Base case, selected when the flag is @'False@: the value of type @a@ is
-- not a function, i.e. every argument has been consumed, so it is lifted
-- into the continuation monad with 'pure', giving @Cont r a@.
--
-- The result type is tied down by an equality constraint rather than being
-- written as @Cont r a@ in the instance head, mirroring the @'True@ instance.
-- GHC picks instances by matching heads without unifying, so a @Cont r a@
-- head only fires once the result type is already known to be exactly that
-- shape. With a bare @rs@ head the instance is chosen as soon as the flag is
-- @'False@, and the constraint then fixes @rs@. This lets the answer type be
-- inferred from @a@: e.g. @print (evalCont (cpsRec not True))@ is no longer
-- ambiguous.
instance (rs ~ Cont r a) => GContRec 'False a rs where
  gcpsRec _ = pure
-- | User-facing entry point: converts a function of any arity into its
-- continuation-passing counterpart. Unlike 'GContRec' it has no flag
-- parameter; its catch-all instance computes the flag and delegates.
class ContRec fn res where
  cpsRec :: fn -> res
-- | The catch-all 'ContRec' instance for 'Cont': decide with 'IsFun' whether
-- the argument is a function and hand it to the matching 'GContRec' instance.
instance GContRec (IsFun fn) fn res => ContRec fn res where
  cpsRec = gcpsRec flag
    where
      flag = Proxy :: Proxy (IsFun fn)
-- Examples: 'cpsRec' converts functions of any arity.

-- | 'not' in continuation-passing style (a one-argument function).
notCpsT :: Bool -> Cont r Bool
notCpsT b = cpsRec not b
-- | Direct-style sum of three integers, used as a three-argument example.
addT :: Int -> Int -> Int -> Int
addT x y z = sum [x, y, z]
-- | 'addT' in continuation-passing style: the same three arguments, with the
-- sum delivered inside @Cont r@.
addCpsT :: Int -> Int -> Int -> Cont r Int
addCpsT x y z = cpsRec addT x y z
-- | Right fold with @(+)@ in continuation-passing style: takes the seed, then
-- the list, and delivers the total inside @Cont r@.
foldrCpsT :: Int -> [Int] -> Cont r Int
foldrCpsT seed = cpsRec (foldr (+) seed)