
How to check recursive call results in CPS code


So I'm working on a function that, given an int list, finds a sequence of arithmetic operations that combines the numbers into a target number. I'm not allowed to use throw/callcc. Only addition and multiplication are valid operations here, and they are left associative (applied strictly left to right).

datatype operation = ADD | MULT

(* find_op: int -> int list -> (operation list -> 'a) -> (unit -> 'a) -> 'a *)
fun find_op x [] s k = k()
  | find_op x [y] s k = if x=y then s([]) else k()
  | find_op x (y1::y2::ys) s k =
    let
      val add = find_op x ((y1+y2)::ys) (fn a => s(ADD::a)) k
      val mul = find_op x ((y1*y2)::ys) (fn a => s(MULT::a)) k
    in
      (* need some work here *)
    end

The function should work as follows:

Given the list [1,1,2,~1] and the target number ~4, an accepted operation list would be [ADD,ADD,MULT] or [ADD,MULT,MULT], because (((1+1)+2)*~1) = (((1+1)*2)*~1) = ~4. But [MULT,ADD,MULT] is not valid, since (((1*1)+2)*~1) = ~3.
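To make the left-associative semantics concrete, here is a minimal sketch of a checker that evaluates a number list under a given operation list. `eval` is a hypothetical helper, not part of the question's code; it assumes there is exactly one operation per gap between numbers and returns NONE otherwise.

```sml
datatype operation = ADD | MULT

(* eval : int list -> operation list -> int option
 * Hypothetical checker: folds the numbers strictly left to right,
 * applying the i-th operation between the running total and the next
 * number. Returns NONE if there isn't exactly one operation per gap. *)
fun eval [] _ = NONE
  | eval (n :: ns) ops =
      let
        fun go acc [] [] = SOME acc
          | go acc (m :: ms) (ADD :: os) = go (acc + m) ms os
          | go acc (m :: ms) (MULT :: os) = go (acc * m) ms os
          | go _ _ _ = NONE
      in
        go n ns ops
      end
```

With this, `eval [1,1,2,~1] [ADD,ADD,MULT]` and `eval [1,1,2,~1] [ADD,MULT,MULT]` both give `SOME ~4`, while `eval [1,1,2,~1] [MULT,ADD,MULT]` gives `SOME ~3`, matching the example.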

I'm confused about how to check whether a recursive call's result came from k(). I can't use = to compare the returned value, because its type 'a is polymorphic (and not necessarily an equality type). Is there any way to handle this?


Solution

  • What you have to do is try the two strategies sequentially: first try reducing the numbers via ADD, and only if that fails, reduce them via MULT. To do this, pass a custom failure continuation (k) to the recursive call for the first strategy. If that strategy fails, its failure continuation tries the second strategy.

    You can't try both strategies at the same time and have both succeed: the function's type only lets it return a single answer. To return every correct answer, you'd need a different type, for example one that produces an operation list list.

    datatype operation = ADD | MULT
    
    fun opToString ADD  = "ADD"
      | opToString MULT = "MULT"
    
    (* find_op: int -> int list -> (operation list -> 'a) -> (unit -> 'a) -> 'a *)
    fun find_op x [] s k               = k ()
      | find_op x [y] s k              = if x = y then s [] else k ()
      | find_op x (y1 :: y2 :: ys) s k =
        let
          (* You need a custom failure continuation that tries the MULT variant
           * if the ADD one fails.
           *)
          fun whenAddFails () =
            find_op x ((y1 * y2) :: ys) (fn a => s (MULT :: a)) k
          val add =
            find_op x ((y1 + y2) :: ys) (fn a => s (ADD :: a)) whenAddFails
        in
          add
        end
    
    
    fun test () =
      let
        val opList = [1,1,2,~1]
        val target = ~4
        fun success ops =
          "success: " ^ (String.concatWith " " (List.map opToString ops))
        fun failure () =
          "couldn't parse numbers as an operation list"
      in
        find_op target opList success failure
      end
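As noted above, find_op can hand back only one answer. If you do want every solution, one common alternative (a different technique from the answer's, sketched here as a hypothetical variant rather than a drop-in replacement) is to pass the failure continuation into the success continuation as a "retry" thunk, so the caller can resume the search after each hit. The names `find_all_ops` and `allSolutions` are assumptions for illustration.

```sml
datatype operation = ADD | MULT

(* find_all_ops : int -> int list
 *                -> (operation list -> (unit -> 'a) -> 'a)
 *                -> (unit -> 'a) -> 'a
 * Hypothetical variant of find_op: the success continuation also
 * receives a retry thunk that resumes the search for the next solution. *)
fun find_all_ops x [] s k = k ()
  | find_all_ops x [y] s k = if x = y then s [] k else k ()
  | find_all_ops x (y1 :: y2 :: ys) s k =
      let
        (* Fallback: try MULT once every ADD-first path is exhausted. *)
        fun tryMult () =
          find_all_ops x ((y1 * y2) :: ys) (fn a => s (MULT :: a)) k
      in
        (* s (ADD :: a) is partially applied; it still expects the retry thunk. *)
        find_all_ops x ((y1 + y2) :: ys) (fn a => s (ADD :: a)) tryMult
      end

(* Collect every solution: record each one, then immediately retry. *)
fun allSolutions x ys =
  find_all_ops x ys (fn ops => fn retry => ops :: retry ()) (fn () => [])
```

Here `allSolutions ~4 [1,1,2,~1]` yields `[[ADD,ADD,MULT],[ADD,MULT,MULT]]`, the two lists from the question. On the original concern about inspecting results: the callee never needs to compare its `'a` result, because the caller chooses `'a`. For example, with the answer's find_op, `find_op target opList SOME (fn () => NONE)` has type `operation list option`, which you can inspect with `case`. Tracing the answer's `test ()` by hand, it returns `"success: ADD ADD MULT"`, since ADD is tried first at every step.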