Search code examples
recursion functional-programming ocaml

Trace a nested recursion in OCaml


I am trying to understand deeply nested recursion in OCaml by using a list-sorting algorithm (insertion sort). For this reason I am tracing the code below, which has a recursive function sort that calls another function insert.


(* Insertion sort on int lists. [sort] sorts the tail recursively and then
   inserts the head into that sorted tail. Both functions are declared in
   one [let rec ... and ...] group. *)
let rec sort : int list -> int list = function
  | [] -> []
  | head :: tail -> insert head (sort tail)

(* [insert elt lst] puts [elt] just before the first [head] of [lst] for
   which [elt <= head], or at the end when there is none. [lst] is expected
   to be sorted already. *)
and insert elt lst =
  match lst with
  | [] -> [ elt ]
  | head :: tail ->
      if elt <= head then elt :: lst
      else
        let rest = insert elt tail in
        head :: rest

I understand the first recursive calls for sort, but after that I cannot follow.

For instance, suppose we have the list [6; 2; 5; 3]. After the tail of this list has been sorted into [2; 3; 5], where in the code is the head 6 compared to each element of that tail? Can somebody provide a hint for reading the trace results?

utop # sort [6; 2; 5; 3];;        
> sort <-- [6; 2; 5; 3]                                                 
> sort <-- [2; 5; 3]                                                    
> sort <-- [5; 3]                                                       
> sort <-- [3]                                                          
> sort <-- []                                                           
> sort --> []                                                           
> insert <-- 3                                                          
> insert -->                                                            
> insert* <-- []                                                        
> insert* --> [3]                                                       
> sort --> [3]                                                          
> insert <-- 5                                                          
> insert -->                                                            
> insert* <-- [3]                                                       
> insert <-- 5                                                          
> insert -->                                                            
> insert* <-- []                                                        
> insert* --> [5]                                                       
> insert* --> [3; 5]                                                    
> sort --> [3; 5]                                                       
> insert <-- 2                                                          
> insert -->                                                            
> insert* <-- [3; 5]                                                    
> insert* --> [2; 3; 5]                                                 
> sort --> [2; 3; 5]                                                    
> insert <-- 6                                                          
> insert -->                                                            
> insert* <-- [2; 3; 5]                                                 
> insert <-- 6                                                          
> insert -->                                                            
> insert* <-- [3; 5]                                                    
> insert <-- 6                                                          
> insert -->                                                            
> insert* <-- [5]                                                       
> insert <-- 6                                                          
> insert -->                                                            
> insert* <-- []                                                        
> insert* --> [6]                                                       
> insert* --> [5; 6]                                                    
> insert* --> [3; 5; 6]                                                 
> insert* --> [2; 3; 5; 6]                                              
> sort --> [2; 3; 5; 6]                                                 
> 
> - : int list = [2; 3; 5; 6]

Solution

  • First of all, there's no reason for insert and sort to be mutually recursive, since insert doesn't depend on sort. So you could write it like this:

    (* [insert elt lst] returns [lst] with [elt] added just before the first
       element [head] for which [elt <= head], or at the end when no such
       element exists. [lst] is expected to be sorted already. *)
    let rec insert elt lst =
      match lst with
      | [] -> [ elt ]
      | head :: tail ->
          if elt <= head then elt :: lst
          else
            let rest = insert elt tail in
            head :: rest
    
    (* [sort lst] sorts an int list: the tail is sorted recursively, then the
       head is inserted into that sorted tail with [insert]. *)
    let rec sort : int list -> int list = function
      | [] -> []
      | head :: tail -> insert head (sort tail)
    

    Now, what happens in insert? The function inserts an element elt into a sorted list while preserving the invariant that all the elements before elt are smaller than it and all the elements after it are greater than or equal to it.

    Two cases happen:

    • if the list is empty, the invariant is ensured if you just return a list containing the element you were trying to insert.
    • if the list is not empty, it's composed of an element we'll call head and the rest of the list, which we'll call tail. Now we have two new cases:
      • if elt <= head, then all the elements of the list are greater than or equal to elt, so you just return elt :: lst (for example, if you call insert 1 [2; 3; 4] you'll get [1; 2; 3; 4])
      • otherwise, head < elt, so we need to put head in front of the list returned by inserting elt into tail, hence the recursive call to insert elt tail

    Now, when you call sort, it calls insert like this:

    insert head (sort tail)
    

    Why so? Because the invariant only holds if the list you're inserting head into is already sorted (hence the requirement above that insert be given a sorted list). So you need to sort tail before inserting head into it.

    If you have the following list: [3; 2; 1], you'll call

    insert 3 (sort [2; 1])

    which is transformed into

    insert 3 (insert 2 (sort [1]))

    which is transformed into

    insert 3 (insert 2 (insert 1 (sort [])))

    which evaluates to

    insert 3 (insert 2 [1])

    which evaluates to

    insert 3 [1; 2]

    which evaluates to

    [1; 2; 3]

    And your list is sorted.


    [EDIT]

    Here's the code with some printing to see what's happening:

    (* Separator printer meant to be passed as [~pp_sep] to
       [Format.pp_print_list]: writes "; " on [ppf]. The unit argument is
       imposed by the [~pp_sep] signature and carries no information. *)
    let pp_sep ppf () = Format.pp_print_string ppf "; "
    
    (* Instrumented [insert] for int lists. It returns the same list as the
       plain version, but each call first prints "(Insert elt in [lst]" inside
       a vertical [Format] box indented by 2, and prints ")" and closes the box
       once its result is known, so recursive calls appear as nested
       parentheses. *)
    let rec insert elt lst =
      let pp_ints = Format.pp_print_list ~pp_sep Format.pp_print_int in
      Format.printf "@[<v 2>(Insert %d in [%a]" elt pp_ints lst;
      let result =
        match lst with
        | [] -> [ elt ]
        | head :: tail ->
            if head < elt then begin
              (* Break cut: the recursive call is printed on its own line,
                 indented inside the current box. *)
              Format.printf "@,";
              head :: insert elt tail
            end
            else elt :: lst
      in
      Format.printf ")@]";
      result
    
    (* Instrumented [sort] for int lists. For a non-empty list it prints
       "(Sort [tail] then insert head" inside a vertical [Format] box indented
       by 2, sorts the tail, inserts the head into the sorted tail, then prints
       ")" and closes the box. The empty list prints nothing. *)
    let rec sort (lst : int list) =
      match lst with
      | [] -> []
      | head :: tail ->
          let pp_ints = Format.pp_print_list ~pp_sep Format.pp_print_int in
          Format.printf "@[<v 2>(Sort [%a] then insert %d@," pp_ints tail
            head;
          (* The tail is sorted first; only then is the head inserted. *)
          let sorted_tail = sort tail in
          let result = insert head sorted_tail in
          Format.printf ")@]@,";
          result
    
    # sort [3;2;1];;
    (Sort [2; 1] then insert 3
      (Sort [1] then insert 2
        (Sort [] then insert 1
          (Insert 1 in []))
        (Insert 2 in [1]
          (Insert 2 in [])))
      (Insert 3 in [1; 2]
        (Insert 3 in [2]
          (Insert 3 in []))))
    - : int list = [1; 2; 3]