Search code examples
Tags: f#, monads, computation-expression

F# passing state into a function in Bind


This is an evolution of the question I asked here (F# Binding to output while carrying the state).

I am trying to get the Bind method to take a function with the signature PlanAccumulator<'a> -> PlanAccumulator<'b>; the getFood function is an example of such a function. I want the Bind method to call getFood with the PlanAccumulator<'a> object.

The challenge is getting the Computation Expression (CE) to call the getFood function inside of the Bind method using the PlanAccumulator that exists at that point in the CE. I'm not sure how to do this but I feel like it should be possible.

/// Identifier of a step in the plan.
type StepId = StepId of int
/// Plan-wide state: the StepId assigned to the most recently added step.
type State = {
  LastStepId : StepId
}

type Food =
    | Chicken
    | Rice

/// One step of the plan; each case carries the StepId it was assigned.
type Step =
  | GetFood of StepId * Food
  | Eat of StepId * Food
  | Sleep of StepId * duration:int

/// Value threaded through the builder: the current State, the steps recorded
/// so far (newest first, since steps are prepended), and a current value 'T.
type PlanAccumulator<'T> = PlanAccumulator of State * Step list * 'T

// Fixed seed (123) so the sequence of random draws is reproducible.
let rng = System.Random(123)

/// Draws a random food, records it as a GetFood step numbered one past the
/// current LastStepId, and makes that food the accumulator's current value.
/// The incoming current value is ignored.
let getFood (PlanAccumulator (state, steps, _)) =
  printfn "GetFood"
  let food = if rng.NextDouble() > 0.5 then Food.Chicken else Food.Rice
  let (StepId previousId) = state.LastStepId
  let stepId = StepId (previousId + 1)
  let step = GetFood (stepId, food)
  PlanAccumulator ({ state with LastStepId = stepId }, step :: steps, food)

/// The question's builder. It threads a PlanAccumulator (state, steps, current
/// value) through the plan. Steps are prepended as they are added, so the list
/// is newest-first until Run reverses it. The constructor argument `state` is
/// the starting State used by Yield.
type PlanBuilder (state: State) =

    /// Applies f to the current value and puts f's steps in front of the
    /// existing ones. Note: the `state` bound by this pattern shadows the
    /// builder's constructor argument.
    member this.For (PlanAccumulator (state, steps1, res):PlanAccumulator<'T>, f:'T -> PlanAccumulator<'R>) : PlanAccumulator<'R> =
      printfn "For"
      let (PlanAccumulator(state2, steps2, res2)) = f res
      PlanAccumulator (state2, steps2 @ steps1, res2)

    /// Intended to run `input` on the accumulator built so far and feed its
    /// value to f. Bind only receives `input` and `f`, however, so nothing in
    /// its parameters gives it access to that accumulator.
    member this.Bind (input:PlanAccumulator<'a> -> PlanAccumulator<'T>, f:'T -> PlanAccumulator<'R>) : PlanAccumulator<'R> =
        printfn "Bind"
        // THIS IS THE PROBLEM: How do I get the previous PlanAccumulator to
        // this point in the computation?
        // NOTE(review): as written this cannot type-check. `previousAccumulator`
        // is not defined anywhere in the code shown, and `state1 res` applies
        // the State record (first component of the accumulator) to `res` as if
        // it were a function. Presumably `f res` was intended.
        let PlanAccumulator (state1, steps1, res) = input previousAccumulator 
        let (PlanAccumulator(state2, steps2, res2)) = f (state1 res)
        PlanAccumulator (state2, steps2 @ steps1, res2)

    /// Starts a plan from the builder's initial state, with no steps and
    /// current value x.
    member this.Yield x = 
        printfn "Yield"
        PlanAccumulator (state, [], x)

    /// Returns the final state and the steps in the order they were added.
    member this.Run (PlanAccumulator (s, p, r)) = 
        printfn "Run"
        s, List.rev p

    /// Custom `eat` operation: records an Eat step numbered one past
    /// LastStepId. `food` is a projection evaluated against the current value
    /// r, and r is passed through unchanged.
    [<CustomOperation("eat", MaintainsVariableSpace=true)>]
    member this.Eat (PlanAccumulator(s, p, r), [<ProjectionParameter>] food) =
        printfn $"Eat: {food}"
        let (StepId lastStepId) = s.LastStepId
        let nextStepId = StepId (lastStepId + 1)
        let newState = { s with LastStepId = nextStepId }
        let newStep = Eat (nextStepId, (food r))
        PlanAccumulator (newState, newStep::p, r)

    /// Custom `sleep` operation: records a Sleep step numbered one past
    /// LastStepId. `duration` is a projection evaluated against the current
    /// value r, and r is passed through unchanged.
    [<CustomOperation("sleep", MaintainsVariableSpace=true)>]
    member this.Sleep (PlanAccumulator (s, p, r), [<ProjectionParameter>] duration) =
        printfn $"Sleep: {duration}"
        let (StepId lastStepId) = s.LastStepId
        let nextStepId = StepId (lastStepId + 1)
        let newState = { s with LastStepId = nextStepId }
        let newStep = Sleep (nextStepId, (duration r))
        PlanAccumulator (newState, newStep::p, r)

// let plan = PlanBuilder()
// Step numbering starts at 0, so the first recorded step gets StepId 1.
let initialState = {
  LastStepId = StepId 0
}

// Intended usage. Run returns (final state, steps in order). The `let!` here
// desugars to Bind, which is the member the question cannot yet implement.
let newState, testPlan =
  PlanBuilder initialState {
      let! food = getFood
      sleep 5
      eat Chicken
  }

Here's an example of what testPlan would be if this worked as desired:

val testPlan : Step list =
    [
        GetFood (StepId 1, Chicken)
        Sleep (StepId 2, 5)
        Eat (StepId 3, Chicken)
    ]

Solution

  • I think you want a plain old state monad, like the one defined at the top of the code below. Using it, I reworked your code as follows:

    /// A state computation: wraps a function that takes an input state of type
    /// 's and returns a result of type 'a paired with the updated state.
    type State<'s, 'a> = State of ('s -> ('a * 's))
    
    module State =
        /// Runs computation `x` starting from `state`; gives back (result, final state).
        let inline run state x =
            match x with
            | State step -> step state
        /// Yields the current state as the result, leaving the state unchanged.
        let get = State(fun s -> s, s)
        /// Replaces the current state with `newState`; the result is unit.
        let put newState = State(fun _ -> (), newState)
        /// Transforms the result of `s` with `f`, threading the state through untouched.
        let map f s =
            State(fun (state: 's) ->
                let value, next = run state s
                f value, next)
    
    /// Builder for the state monad. A State<'s, 'a> computation threads an
    /// explicit state of type 's through a sequence of steps, so code written
    /// inside `state { ... }` can read the state (State.get) and replace it
    /// (State.put) without passing it around by hand.
    ///
    /// Delay, While and For do their work when the computation is *run*, not
    /// when it is built. Side effects in a `state { }` body, loop guards and
    /// sequences are therefore evaluated on every run.
    type StateBuilder() =
        member this.Zero () = State(fun s -> (), s)
        member this.Return x = State(fun s -> x, s)
        member inline this.ReturnFrom (x: State<'s, 'a>) = x
        member this.Bind (x, f) : State<'s, 'b> =
            State(fun state ->
                let (result: 'a), state = State.run state x
                State.run state (f result))
        /// Runs x1 for its effect on the state, discards its result, then runs x2.
        member this.Combine (x1: State<'s, 'a>, x2: State<'s, 'b>) =
            State(fun state ->
                let _, state = State.run state x1
                State.run state x2)
        /// Defers building the body until the computation runs. Calling `f ()`
        /// directly would execute everything before the first let! (printfn,
        /// random draws, ...) only once, at definition time.
        member this.Delay (f: unit -> State<'s, 'a>) : State<'s, 'a> =
            State(fun state -> State.run state (f ()))
        /// Runs `body` once per element of `sequence`, in order, threading the
        /// state through and discarding each iteration's result. An empty
        /// sequence leaves the state unchanged (Seq.reduceBack would throw).
        /// NOTE(review): the result is now unit; a `for` body that ended in a
        /// non-unit `return` previously yielded its last value.
        member this.For (sequence: seq<'a>, body: 'a -> State<'s, 'b>) : State<'s, unit> =
            State(fun state ->
                let final =
                    sequence
                    |> Seq.fold (fun current item -> snd (State.run current (body item))) state
                (), final)
        /// Re-evaluates `guard` before every iteration at run time, and loops
        /// iteratively, so there is no unbounded recursion while the
        /// computation is being built.
        member this.While (guard: unit -> bool, body: State<'s, 'a>) : State<'s, unit> =
            State(fun state ->
                let mutable current = state
                while guard () do
                    current <- snd (State.run current body)
                (), current)
    
    /// The `state { ... }` computation-expression builder instance.
    let state = StateBuilder ()
    
    /// Identifier of a step in the plan.
    type StepId = StepId of int
    
    /// Plan-wide state: the StepId assigned to the most recently added step.
    type PlanState = {
        LastStepId : StepId
    }
    
    type Food =
        | Chicken
        | Rice
    
    /// One step of the plan; each case carries the StepId it was assigned.
    type Step =
        | GetFood of StepId * Food
        | Eat of StepId * Food
        | Sleep of StepId * duration:int
    
    /// The state threaded by the state monad: plan state plus the recorded
    /// steps, newest first (each step is prepended with ::).
    type PlanAccumulator = PlanAccumulator of PlanState * Step list
    
    // Fixed seed (123) so the sequence of random draws is reproducible.
    let rng = System.Random(123)
    
    /// Picks a random food, records it as a GetFood step numbered one past the
    /// current LastStepId, and returns the food.
    /// The printout and the random draw come after the first `let!`, inside the
    /// bind continuation. They therefore happen each time the computation is
    /// run, even if the builder's Delay evaluates its body eagerly. Otherwise,
    /// because getFood is a value, they would run once at definition time and
    /// every use would yield the same food.
    let getFood =
        state {
            let! (PlanAccumulator (planState, steps)) = State.get
            printfn "GetFood"
            let randomFood =
                if rng.NextDouble() > 0.5 then Food.Chicken
                else Food.Rice
            let (StepId lastStepId) = planState.LastStepId
            let nextStepId = StepId (lastStepId + 1)
            let newState = { planState with LastStepId = nextStepId }
            let newStep = GetFood (nextStepId, randomFood)
            do! State.put (PlanAccumulator (newState, newStep :: steps))
            return randomFood
        }
    
    /// Appends an Eat step for `food`, numbered one past the current LastStepId.
    let eat food =
        state {
            printfn "Eat: %A" food
            let! (PlanAccumulator (current, history)) = State.get
            let (StepId previousId) = current.LastStepId
            let stepId = StepId (previousId + 1)
            let updated = { current with LastStepId = stepId }
            do! State.put (PlanAccumulator (updated, Eat (stepId, food) :: history))
        }
    
    /// Appends a Sleep step of `duration`, numbered one past the current LastStepId.
    let sleep duration =
        state {
            printfn "Sleep: %A" duration
            let! (PlanAccumulator (current, history)) = State.get
            let (StepId previousId) = current.LastStepId
            let stepId = StepId (previousId + 1)
            let updated = { current with LastStepId = stepId }
            do! State.put (PlanAccumulator (updated, Sleep (stepId, duration) :: history))
        }
    
    /// Plan state before any step is recorded; the first step gets StepId 1.
    let initialState = { LastStepId = StepId 0 }

    /// Starting accumulator: the initial plan state and no steps yet.
    let initialPlan = PlanAccumulator (initialState, [])
    
    [<EntryPoint>]
    let main argv =
        // Build the plan first, then run it starting from the empty plan.
        let plan =
            state {
                let! food = getFood
                do! sleep 10
                do! eat food
            }
        let (_, testPlan) = State.run initialPlan plan
        // testPlan is the final PlanAccumulator; its step list is newest-first.
        printfn "%A" testPlan
        0
    

    Output is:

    GetFood
    Sleep: 10
    Eat: Chicken
    PlanAccumulator
      ({ LastStepId = StepId 3 },
       [Eat (StepId 3, Chicken); Sleep (StepId 2, 10); GetFood (StepId 1, Chicken)])