
Can I simplify this recursive concat function using List.fold_left?


I have created a working solution for concat, but I feel that I can shorten it using List.fold_left.

Here is my current code:

let rec concat (lists : 'a list list) : 'a list =
    match lists with
    | [] -> []
    | hd :: tl -> hd @ concat tl ;;

Here is what I have tried:

let concat (lists : 'a list list) : 'a list =
    List.fold_left @ lists ;;

This gives me the error: This expression has type 'a list list but an expression was expected of type 'a list

I think this is because List.fold_left returns a list, but since we are feeding it a list of lists, it returns a list of lists again. How can I get around this without pattern matching?

I was also playing around with List.map but with no luck so far:

let concat (lists : 'a list list) : 'a list =
    List.map (fun x -> List.fold_left @ x) lists ;;

Solution

  • Consider the type signature of List.fold_left:

    ('a -> 'b -> 'a) -> 'a -> 'b list -> 'a
    

    List.fold_left takes three arguments.

    1. A function.
    2. An initial value.
    3. A list to iterate over.
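To make the three arguments concrete, here is a small sketch that folds over a list of integers. The names `sum` and `trace` are only for illustration. The string-building version shows the left-to-right order in which the accumulator is combined with each element.

```ocaml
(* List.fold_left f init [a; b; c] computes f (f (f init a) b) c *)

(* 1. the function: combines the accumulator with one element
   2. the initial value: 0
   3. the list to iterate over: xs *)
let sum xs = List.fold_left (fun acc x -> acc + x) 0 xs

(* Same fold, but building a string makes the left-nested order visible *)
let trace xs =
  List.fold_left (fun acc x -> Printf.sprintf "(%s + %d)" acc x) "0" xs

let () =
  Printf.printf "%d\n" (sum [1; 2; 3]);  (* prints 6 *)
  print_endline (trace [1; 2; 3])        (* prints (((0 + 1) + 2) + 3) *)
```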
    Your attempt was:

    List.fold_left @ lists
    

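    You're making two mistakes.

    First off, without parentheses @ is parsed as the infix append operator, so this reads as (List.fold_left) @ (lists): an attempt to append a function to a list.

    To pass the operator itself as an argument you need to write (@), as in List.fold_left (@) lists. But that's still not quite right, because...

    You're only passing two arguments, so lists becomes the initial value, while List.fold_left expects three.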
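Both mistakes can be checked directly with the compiler. In this sketch the names `appended` and `still_waiting` are hypothetical. It shows that `(@)` in parentheses is an ordinary two-argument function. It also shows that giving `List.fold_left` only two arguments leaves a function that is still waiting for the list to fold over. The type annotation on `still_waiting` is accepted only because of that.

```ocaml
(* Parenthesized, the infix operator @ becomes an ordinary function of
   type 'a list -> 'a list -> 'a list, so it can be passed to a fold. *)
let appended = (@) [1; 2] [3]  (* same as [1; 2] @ [3] *)

(* With only two arguments, the second one ([[1; 2]; [3]]) becomes the
   initial accumulator, and the result is a function that still expects
   the list to iterate over. This annotation type-checks, which shows
   that no list has been produced yet. *)
let still_waiting : int list list list -> int list list =
  List.fold_left (@) [[1; 2]; [3]]

let () =
  assert (appended = [1; 2; 3]);
  (* Supplying the missing third argument finally yields a value *)
  assert (still_waiting [[[4]]; [[5; 6]]] = [[1; 2]; [3]; [4]; [5; 6]]);
  print_endline "ok"
```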

    I think that you're looking for something like:

    let concat lists = List.fold_left (@) [] lists
    

    Demonstrated:

    # let concat lists = List.fold_left (@) [] lists in
      concat [[1;2;3]; [4;5;6]; [7;8;9]];;
    - : int list = [1; 2; 3; 4; 5; 6; 7; 8; 9]
    

    Runtime complexity

    The danger with this approach to concatenating lists is that, although it runs in constant stack space (List.fold_left is tail-recursive, and so is @, at least as of this edit), its runtime complexity is O(n²).

    concat [[1; 2; 3]; [4; 5; 6]; [7; 8; 9]]
    

    is equivalent to:

    (([] @ [1; 2; 3]) @ [4; 5; 6]) @ [7; 8; 9]
    

    Because @ copies every cell of its left argument, this code has to walk (and copy) [1; 2; 3] to build [1; 2; 3; 4; 5; 6], and then walk and copy all of [1; 2; 3; 4; 5; 6] to build [1; 2; 3; 4; 5; 6; 7; 8; 9].

    Now imagine our list of lists contained a very large number of lists, each of them large. Every element gets copied again for each list that follows it, so the runtime quickly gets out of control.
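To put a number on this, here is a small counting model rather than a benchmark. The function name `cells_copied_by_left_fold` is hypothetical. The model relies on the fact that, with OCaml's immutable lists, l1 @ l2 allocates a fresh copy of every cell of l1 and shares l2. Each step of the left fold therefore copies the entire accumulator built so far.

```ocaml
(* Counting model, not a benchmark: each step of
   List.fold_left (@) [] lists copies the accumulator built so far,
   because acc @ l copies every cell of acc and shares l. *)
let cells_copied_by_left_fold (lists : 'a list list) : int =
  let _acc_len, copied =
    List.fold_left
      (fun (acc_len, copied) l -> (acc_len + List.length l, copied + acc_len))
      (0, 0) lists
  in
  copied

let () =
  (* 0 + 3 + 6 = 9 cells copied *)
  Printf.printf "%d\n"
    (cells_copied_by_left_fold [[1; 2; 3]; [4; 5; 6]; [7; 8; 9]]);
  (* 1000 singleton lists: 0 + 1 + ... + 999 = 499500 cells copied,
     although the result only has 1000 elements *)
  Printf.printf "%d\n"
    (cells_copied_by_left_fold (List.init 1000 (fun i -> [i])))
```

With k lists of m elements each, the model gives m·k(k-1)/2 copied cells for a result of only n = m·k elements. That is where the quadratic behaviour comes from.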

    Instead, we can use ::, which runs in constant time, to bring this down to O(n), as in the following code.

    let rec concat = function
      | [] -> []
      | []::tl -> concat tl
      | (x::xs)::tl -> x :: concat (xs :: tl)
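For contrast with the left fold, here is a hypothetical instrumented copy of the :: version that counts how many times concat is entered. The `calls` counter and the `count_calls` helper are illustrative additions. Each call does exactly one of three things: it emits one element, skips one exhausted inner list, or ends the recursion. For n elements spread over k inner lists, the function therefore makes exactly n + k + 1 calls, which is linear in the input.

```ocaml
(* Hypothetical instrumentation: a global counter incremented on every call *)
let calls = ref 0

let rec concat = function
  | [] -> incr calls; []                                  (* no lists left *)
  | [] :: tl -> incr calls; concat tl                     (* skip an exhausted inner list *)
  | (x :: xs) :: tl -> incr calls; x :: concat (xs :: tl) (* emit one element *)

(* Reset the counter, run concat, and return the result with the call count *)
let count_calls lists =
  calls := 0;
  let result = concat lists in
  (result, !calls)

let () =
  let result, n = count_calls [[1; 2; 3]; [4; 5; 6]; [7; 8; 9]] in
  assert (result = [1; 2; 3; 4; 5; 6; 7; 8; 9]);
  (* 9 elements + 3 inner lists + 1 final [] = 13 calls *)
  Printf.printf "%d\n" n;
  (* 1000 singleton lists: 1000 + 1000 + 1 = 2001 calls *)
  let _, n = count_calls (List.init 1000 (fun i -> [i])) in
  Printf.printf "%d\n" n
```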
    

    Of course, this is not tail-recursive, so it runs in linear stack space, but fortunately the [@tail_mod_cons] attribute (available since OCaml 4.14) gives us an easy fix.

    let[@tail_mod_cons] rec concat = function
      | [] -> []
      | []::tl -> concat tl
      | (x::xs)::tl -> x :: concat (xs :: tl)
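Here is a quick sanity check of the tail_mod_cons version, assuming OCaml 4.14 or later. On a small input it compares the result with the standard library's List.concat, and it also concatenates two one-million-element lists. The sizes are arbitrary. The check only shows that the function completes and returns the right result. It does not measure how much stack the untransformed version would need.

```ocaml
(* Requires OCaml 4.14 or later for the [@tail_mod_cons] attribute *)
let[@tail_mod_cons] rec concat = function
  | [] -> []
  | [] :: tl -> concat tl                      (* ordinary tail call *)
  | (x :: xs) :: tl -> x :: concat (xs :: tl)  (* tail call modulo cons *)

let () =
  (* Agrees with the standard library on a small input with empty lists *)
  let small = [[1; 2]; []; [3]; []] in
  assert (concat small = List.concat small);
  (* Two one-million-element lists around a short one *)
  let big = List.init 1_000_000 Fun.id in
  let result = concat [big; []; [1; 2]; big] in
  assert (List.length result = 2_000_002);
  assert (List.nth result 1_000_000 = 1);
  print_endline "ok"
```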