Tags: haskell, solver, smt, sat, sbv

Sequence of states in Haskell SBV doesn't satisfy constraints


I have a symbolic enumeration like the following:

data State = Start | Dot
mkSymbolicEnumeration ''State

The function that evaluates whether a state is valid within a sequence, relative to the previous state, is defined so that sDot should only be preceded by sStart, and sStart should only be preceded by sDot. In theory, this means we should never have two consecutive sStart or two consecutive sDot in our sequence:

validSequence :: SList State -> SInteger -> SBool
validSequence seq i = case seq .!! i of
    sStart -> p1 .== sDot      -- sStart can only be preceded by sDot
    sDot   -> p1 .== sStart    -- sDot can only be preceded by sStart
    where p1 = seq .!! (i-1)

Then two groups of constraints are declared. The first states that seq should have length n, and the second states that every seq .!! i with i /= 0 should satisfy validSequence:

-- sequence should be of length n
constrain $ L.length seq .== fromIntegral n

-- apply a validSequence constraint for every i in [1..n]
mapM_ (constrain . (validSequence seq) . fromIntegral) [1..n]

When I load this module into ghci, the result I get is different from what I would expect:

runSMT $ answer 10
-- expecting this: [Dot, Start, Dot, Start, Dot, Start, Dot, Start, Dot, Start]
-- or this:        [Start, Dot, Start, Dot, Start, Dot, Start, Dot, Start, Dot]
-- actual result:  [Dot, Dot, Dot, Dot, Dot, Dot, Dot, Dot, Dot, Dot]
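To make the intended property concrete, here is a plain-Haskell check over ordinary lists (no SBV involved; `alternates`, `expected1`, `expected2` and `actual` are hypothetical names used only for illustration). It accepts both expected outputs and rejects the all-Dot result:

```haskell
-- Plain Haskell mirror of the enumeration (not the SBV version).
data State = Start | Dot
  deriving (Eq, Show)

-- True when no two neighbouring elements are equal, i.e. the list
-- alternates between Start and Dot.
alternates :: [State] -> Bool
alternates xs = and (zipWith (/=) xs (drop 1 xs))

-- The two outputs the question expects, and the one it actually got.
expected1, expected2, actual :: [State]
expected1 = take 10 (cycle [Dot, Start])
expected2 = take 10 (cycle [Start, Dot])
actual    = replicate 10 Dot
```

Here `alternates expected1` and `alternates expected2` are True, while `alternates actual` is False, which is exactly the mismatch being asked about.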

What I don't understand:

  • Why doesn't the actual result satisfy the constraint that Dot should only follow Start?
  • In particular, am I doing something wrong in validSequence?
  • Or alternatively, am I using the mapM_ call in a wrong way?

Complete reproducible code is as follows (requires the SBV library):

{-# LANGUAGE DeriveAnyClass      #-}
{-# LANGUAGE DeriveDataTypeable  #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}
{-# LANGUAGE TemplateHaskell     #-}

module Sandbox where

import           Data.SBV
import           Data.SBV.Control

import           Data.SBV.List    ((.!!))
import qualified Data.SBV.List    as L


data State = Start | Dot
mkSymbolicEnumeration ''State

validSequence :: SList State -> SInteger -> SBool
validSequence seq i = case seq .!! i of
    sStart -> p1 .== sDot      -- sStart can only be preceded by sDot
    sDot   -> p1 .== sStart    -- sDot can only be preceded by sStart
    where p1 = seq .!! (i-1)


answer :: Int -> Symbolic [State]
answer n = do
    seq <- sList "seq"

    -- sequence should be of length n
    constrain $ L.length seq .== fromIntegral n

    -- apply a validSequence constraint for every i in [1..n]
    mapM_ (constrain . (validSequence seq) . fromIntegral) [1..n]

    query $ do cs <- checkSat
               case cs of
                    Unk    -> error "Solver returned unknown!"
                    DSat{} -> error "Unexpected dsat result!"
                    Unsat  -> error "Solver couldn't find a satisfiable solution"
                    Sat    -> getValue seq

Solution

  • validSequence :: SList State -> SInteger -> SBool
    validSequence seq i = case seq .!! i of
        sStart -> p1 .== sDot      -- sStart can only be preceded by sDot
        sDot   -> p1 .== sStart    -- sDot can only be preceded by sStart
        where p1 = seq .!! (i-1)
    

    is equivalent to

    validSequence :: SList State -> SInteger -> SBool
    validSequence seq i = case seq .!! i of
        _  -> p1 .== sDot
        where p1 = seq .!! (i-1)
    

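    since sStart in a case pattern is the name of a fresh, local variable that matches any value and has no relation to the top-level sStart with the same name. In the first alternative, sDot still refers to the top-level constant, so every constraint just says "the previous element is Dot". Since indices 0 to n-1 each appear as i-1 for some i in [1..n], this forces every element to be Dot, which is exactly the result you got. Turning on warnings in GHC (e.g. with -Wall) should report this name shadowing, as well as the redundant second alternative.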

    I can't suggest how to fix this because I am unfamiliar with SBV. In particular, I can't tell whether the check (seq .!! i) == sStart you are trying to make can be done at the Haskell level, or whether it must be performed at the SBV level so that it generates the right formula to pass to the SMT solver.

    Maybe you need something like (pseudo code):

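    validSequence :: SList State -> SInteger -> SBool
    validSequence seq i =
        (p2 .== sStart .&& p1 .== sDot) .||   -- Start preceded by Dot
        (p1 .== sStart .&& p2 .== sDot)       -- Dot preceded by Start
      where p1 = seq .!! (i-1)   -- previous element
            p2 = seq .!! i       -- current element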
    

    EDIT: an actual working implementation based on the pseudo code above, written with SBV's DSL:

    validSequence :: SList State -> SInteger -> SBool
    validSequence seq i =
          ite (cur .== sStart) (prev `sElem` [sDot])
        $ ite (cur .== sDot)   (prev `sElem` [sStart])
          sFalse
        where cur  = seq .!! i
              prev = seq .!! (i-1)
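The root cause identified in this answer is ordinary Haskell pattern semantics and is not specific to SBV. A minimal sketch with a plain Int constant, where `start`, `classifyBuggy` and `classifyFixed` are hypothetical names standing in for sStart and validSequence:

```haskell
-- A top-level constant, playing the role of the sStart that
-- mkSymbolicEnumeration generates.
start :: Int
start = 0

-- Buggy: 'start' in the pattern is a NEW variable that matches anything.
-- GHC reports the second alternative as redundant, and with -Wall also
-- warns that this 'start' shadows the top-level binding.
classifyBuggy :: Int -> String
classifyBuggy n = case n of
  start -> "start"
  _     -> "other"

-- Fixed: compare against the constant with (==) in a guard.
classifyFixed :: Int -> String
classifyFixed n
  | n == start = "start"
  | otherwise  = "other"
```

`classifyBuggy 42` returns "start" because the pattern variable matches any value, while `classifyFixed 42` returns "other". The guard works here only because `n` is a concrete Int. In SBV, `seq .!! i` is symbolic and `.==` returns an SBool rather than a Bool, so it cannot drive a guard or a case; the comparison has to stay inside the formula, via `ite`, `.&&` and `.||` as in the EDIT.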
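For completeness, a minimal end-to-end sketch of the question's program with a symbolic check, assuming the same SBV version and GHC extensions as the question plus an installed Z3 (the solver `runSMT` uses by default). Two choices differ from the code above and are assumptions of this sketch: the constraint loop runs over `[1 .. n - 1]`, because a list of length n only has indices 0 to n-1, so `i = n` would refer to an element past the end; and validity is written as `cur ./= prev`, which matches the `ite` version only because State has exactly two constructors.

```haskell
{-# LANGUAGE DeriveAnyClass      #-}
{-# LANGUAGE DeriveDataTypeable  #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}
{-# LANGUAGE TemplateHaskell     #-}

import           Data.SBV
import           Data.SBV.Control
import           Data.SBV.List    ((.!!))
import qualified Data.SBV.List    as L

data State = Start | Dot
mkSymbolicEnumeration ''State

-- Symbolic check that element i differs from element i-1. With only two
-- constructors this means "Start follows Dot and Dot follows Start".
validSequence :: SList State -> SInteger -> SBool
validSequence xs i = cur ./= prev
  where cur  = xs .!! i
        prev = xs .!! (i - 1)

answer :: Int -> Symbolic [State]
answer n = do
  xs <- sList "seq"
  -- the list must have exactly n elements (indices 0 .. n-1)
  constrain $ L.length xs .== fromIntegral n
  -- compare each element with its predecessor, staying within bounds
  mapM_ (constrain . validSequence xs . fromIntegral) [1 .. n - 1]
  query $ do
    cs <- checkSat
    case cs of
      Sat -> getValue xs
      _   -> error "Solver did not return sat"
```

Evaluating `runSMT (answer 10)` should then return one of the two alternating lists; which one depends on the model the solver picks. If more states are added to the enumeration, the `ite` formulation in the EDIT is the one that generalises.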