Search code examples
Tags: haskell, typeclass

Haskell instances under type variable conditions


Starting with a concrete instance of my question, we all know (and love) the Monad type class:

-- Simplified sketch of the class; "..." marks elided parts.
class ... => Monad m where
  return :: a -> m a
  (>>=)  :: m a -> (a -> m b) -> m b
  ...

Consider the following would-be instance, where we modify the standard list/"nondeterminism" instance using nub to retain only one copy of each "outcome":

-- A list wrapper meant to keep only one copy of each result.
newtype DistinctList a = DL { dL :: [a] }
instance Monad DistinctList where
  return = DL . return
  x >>= f = DL . nub $ (dL x) >>= (dL . f)

...Do you spot the error? The problem is that nub :: Eq a => [a] -> [a], so x >>= f is only well-defined when the result element type has an Eq instance, i.e. for f :: Eq b => a -> DistinctList b, whereas the class signature demands f :: a -> DistinctList b for every b. Is there some way I can proceed anyway?

Stepping back, suppose I have a would-be instance that is only defined under some condition on the parametric type's variable. I understand that this is generally not allowed because other code written with the type class cannot be guaranteed to supply parameter values that obey the condition. But are there circumstances where this still can be carried out? If so, how?


Solution

  • Here is an adaptation of the technique used by the set-monad package, applied to your case.

    Note that there is, as there must be, some "cheating": the structure includes extra value constructors that represent "return" and "bind". These act as suspended computations that must later be run. The Eq constraint is needed only by the run function, while the constructors that build the "suspension" are Eq-free.

    {-# LANGUAGE GADTs #-}
    
    import qualified Data.List            as L
    import qualified Data.Functor         as F
    import qualified Control.Applicative  as A
    import Control.Monad
    
    -- | Reference version of the desired bind on plain lists: apply f to
    -- every element, concatenate the results, and keep only the first
    -- occurrence of each value. The use of 'L.nub' is what forces Eq on b.
    dlbind :: Eq b => [a] -> (a -> [b]) -> [b]
    dlbind xs f = L.nub (concatMap f xs)
    
    -- | A list-like computation whose results are deduplicated when run.
    --
    -- 'Return' and 'Bind' are stored as suspended constructors rather than
    -- evaluated immediately, so building a DL needs no Eq constraint;
    -- deduplication (and therefore Eq) is left to 'run'.
    data DL a where
      -- | Wraps an ordinary list.
      Prim   :: [e] -> DL e
      -- | A suspended 'return'.
      Return :: e -> DL e
      -- | A suspended '>>='; the intermediate result type is hidden.
      Bind   :: DL i -> (i -> DL o) -> DL o
    
    -- | Runs a 'DL' computation, returning its distinct results in order of
    -- first occurrence.
    --
    -- Eq is needed only here: the constructors just record the computation,
    -- and duplicates are removed while interpreting it. Left-nested binds
    -- are reassociated to the right before being evaluated.
    run :: Eq a => DL a -> [a]
    -- Deduplicate a bare Prim too; otherwise run (m >>= return) could differ
    -- from run m (e.g. for m = Prim [1,1]).
    run (Prim xs)             = L.nub xs
    run (Return x)            = [x]
    run (Bind (Prim xs) f)    = L.nub $ concatMap (run . f) xs
    run (Bind (Return x) f)   = run (f x)
    -- ((ma >>= f) >>= g)  ==>  ma >>= (\a -> f a >>= g)
    run (Bind (Bind ma f) g)  = run (Bind ma (\a -> Bind (f a) g))
    
    -- | Two computations are equal when running them produces the same list.
    -- Since 'run' yields a list, this comparison is sensitive to the order in
    -- which results appear.
    -- NOTE(review): the original author suggested an order-insensitive
    -- comparison may be more appropriate for a set-like structure; consider it.
    instance (Eq a) => Eq (DL a) where
      left == right = resultsL == resultsR
        where
          resultsL = run left
          resultsR = run right
    
    -- | Shows a computation by running it (executing any pending 'Return' and
    -- 'Bind' nodes) and showing the resulting list. This is why Eq is needed.
    instance (Show a, Eq a) => Show (DL a) where
      show dl = show (run dl)
    
    -- | Mapping is expressed through the Monad instance: bind, then wrap each
    -- result with 'return'.
    instance F.Functor DL where
      fmap f m = m >>= (return . f)
    
    -- | Applicative operations expressed through the Monad instance: run the
    -- function computation, then the argument computation, then apply.
    instance A.Applicative DL where
      pure x    = return x
      mf <*> mx = mf >>= \f -> mx >>= \x -> return (f x)
    
    -- | 'return' and '>>=' do no work themselves: they only record the
    -- operation as a 'Return' or 'Bind' node for 'run' to interpret later.
    instance Monad DL where
      return x = Return x
      m >>= k  = Bind m k
    
    -- Examples: the same bind over an ordinary list (duplicates kept) and
    -- over a DL (duplicates removed when the result is run).
    list :: [Integer]
    list = concatMap (\n -> [n `mod` 2, n `mod` 3]) [1 .. 4]

    dlist :: DL Integer
    dlist = Prim [1 .. 4] >>= \n -> Prim [n `mod` 2, n `mod` 3]
    

    And here is a dirty hack that makes it more efficient, addressing concerns raised in the discussion about how chains of binds are evaluated.

    {-# LANGUAGE GADTs #-}
    
    import qualified Data.List            as L
    import qualified Data.Set             as S
    import qualified Data.Functor         as F
    import qualified Control.Applicative  as A
    import Control.Monad
    
    
    -- | Reference bind on plain lists with duplicate results removed (the
    -- first occurrence of each value is kept). Eq on the result type is
    -- required because of 'L.nub'.
    dlbind :: Eq b => [a] -> (a -> [b]) -> [b]
    dlbind xs f = L.nub $ xs >>= f
    
    -- | Suspended distinct-list computation (efficiency-oriented variant).
    --
    -- 'Prim' captures an Eq dictionary for its element type, so code that
    -- pattern-matches on a 'Prim' (such as rebuilding or joining Prims) can
    -- use Eq without requiring it from the caller.
    data DL a where
      Prim   :: Eq e => [e] -> DL e
      Return :: e -> DL e
      Bind   :: DL i -> (i -> DL o) -> DL o
    --  Fail   :: DL a  -- could be added to clear failure chains
    
    -- | Runs a 'DL' computation, returning its distinct results in order of
    -- first occurrence.
    --
    -- A 'Bind' is first passed through 'foldChain', which may collapse a
    -- left-nested chain over a 'Prim' into a single deduplicated 'Prim'; the
    -- (possibly simplified) computation is then interpreted case by case.
    run :: Eq a => DL a -> [a]
    -- Deduplicate a bare Prim too; otherwise run (m >>= return) could differ
    -- from run m (e.g. for m = Prim [1,1]).
    run (Prim xs)      = L.nub xs
    run (Return x)     = [x]
    run b@(Bind _ _)   =
      case foldChain b of
        (Bind (Prim xs) f)   -> L.nub $ concatMap (run . f) xs
        (Bind (Return a) f)  -> run (f a)
        -- ((ma >>= f) >>= g)  ==>  ma >>= (\a -> f a >>= g)
        (Bind (Bind ma f) g) -> run (Bind ma (\a -> Bind (f a) g))
        -- foldChain returns a Bind for every Bind input, so this arm is not
        -- expected to fire; it keeps the match total instead of relying on
        -- that invariant.
        other                -> run other
    
    -- | Simplifies a left-nested chain (((m >>= f) >>= g) >>= h) from the
    -- inside out: the innermost bind is simplified first, then each enclosing
    -- 'Bind' is handed to 'stepChain'. Anything that is not a 'Bind' is
    -- returned unchanged.
    foldChain :: DL u -> DL u
    foldChain dl = case dl of
      Bind inner k -> stepChain (Bind (foldChain inner) k)
      _            -> dl
    
    -- | One simplification step for a chain of the form (Prim xs >>= f) >>= g.
    --
    -- f is applied to every element of xs. If every result is a 'Prim' (and
    -- there is at least one result), their lists are concatenated and
    -- deduplicated, and the step becomes (Prim zs >>= g). If any result is a
    -- 'Return' or 'Bind', or xs is empty, the input is returned unchanged.
    -- Inputs of any other shape are also returned unchanged.
    --
    -- NOTE(review): when the input is preserved, f has already been applied
    -- to every element here, and run will apply it again after reassociating.
    -- NOTE(review): `let Prim xs = ...` is a lazy pattern binding on a GADT
    -- constructor that carries an Eq context; some GHC versions reject such
    -- bindings. Confirm this compiles with the targeted compiler.
    stepChain :: DL u -> DL u
    stepChain b@(Bind (Bind (Prim xs) f) g) =
      let dys = map f xs
          -- split the results of f by constructor
          pms = [Prim ys   | Prim   ys <- dys]
          ret = [Return ys | Return ys <- dys]
          bnd = [Bind ys f | Bind ys f <- dys]
      in case (pms, ret, bnd) of
           -- ([],[],[]) -> Fail -- could clear failure
           -- all results are Prims: join them onto an empty Prim of the same
           -- element type (built via mkEmpty), then deduplicate
           (dxs@(Prim ys:_),[],[]) -> let Prim xs = joinPrims dxs (Prim $ mkEmpty ys)
                                      in Bind (Prim $ L.nub xs) g       
           _  -> b
    stepChain dxs = dxs
    
    -- | Produces an empty list whose element type is taken from the argument's
    -- type. The argument itself is never inspected.
    mkEmpty :: proxy a -> [a]
    mkEmpty _ = []
    
    -- | Concatenates the lists held by a list of 'Prim's, in order, in front
    -- of the list held by the final argument (which must also be a 'Prim'
    -- when the first argument is non-empty). An empty first argument returns
    -- the final argument unchanged.
    --
    -- Only 'Prim' elements are supported; stepChain calls this with lists in
    -- which every element is a 'Prim'.
    -- NOTE(review): `let Prim xs = ...` is a lazy pattern binding on a GADT
    -- constructor with an Eq context; confirm the targeted GHC accepts it.
    joinPrims :: [DL a] -> DL a -> DL a
    joinPrims [] dys = dys
    joinPrims (Prim zs : dzs) dys = let Prim xs = joinPrims dzs dys in Prim (zs ++ xs)
    joinPrims (_ : _) _ = error "joinPrims: expected every element to be a Prim"
    
    -- | Two computations are equal when running them yields the same list of
    -- results (order-sensitive). Only Eq on the element type is needed,
    -- because 'run' requires only Eq; an Ord context here would needlessly
    -- exclude element types that have Eq but no Ord.
    instance (Eq a) => Eq (DL a) where
      dxs == dys = run dxs == run dys
    
    -- | Orders computations by comparing their run results as lists
    -- (lexicographically).
    instance (Ord a) => Ord (DL a) where
      compare lhs rhs = run lhs `compare` run rhs
    
    -- | Shows the list obtained by running the computation.
    instance (Show a, Eq a) => Show (DL a) where
      show = \dl -> show (run dl)
    
    -- | fmap defined through the Monad instance (bind, then return).
    instance F.Functor DL where
      fmap g dv = do
        v <- dv
        return (g v)
    
    -- | Applicative structure derived from the Monad instance: sequence the
    -- function computation before the argument computation.
    instance A.Applicative DL where
      pure v    = return v
      df <*> dv = do
        g <- df
        v <- dv
        return (g v)
    
    -- | Both operations only build nodes ('Return' / 'Bind'); all evaluation
    -- is deferred to 'run'.
    instance Monad DL where
      return     = \x -> Return x
      (>>=) dv k = Bind dv k
    
    
    -- | Alternative to 'return' that builds a singleton 'Prim' rather than a
    -- 'Return'. This "cheats" by requiring Eq, but because stepChain only
    -- collapses a chain when every intermediate result is a 'Prim', using
    -- return' lets chains like (Prim xs >>= \_ -> return' v) >>= g be folded.
    return' :: Eq a => a -> DL a
    return' x = Prim [x]
    
    -- Examples: an ordinary list bind, the equivalent DL bind, and a chain
    -- that stepChain can collapse because return' yields a Prim.
    s :: [Integer]
    s = concatMap (\n -> [n `mod` 2, n `mod` 3]) [1, 2, 3, 4]

    t :: DL Integer
    t = Prim [1 .. 4] >>= \n -> Prim [n `mod` 2, n `mod` 3]

    r' :: DL Integer
    r' = Prim [1 .. 1000] >>= const (return' 1) >>= const (Prim [1 .. 1000])