
Recursive insertion sort function in Haskell


I have working code for a recursive insertion sort function. I just need clarification on how it works. Right now, I'm confused about what's happening on the third line where it says insert (isort xs). Please explain the rest of the code, step by step if you can. Thanks!

isort :: [Int] -> [Int]
isort [] = []
isort [x] = [x]
isort (x:xs) = insert (isort xs)
  where insert [] = [x]
        insert (y:ys)
          | x < y = x : y : ys
          | otherwise = y : insert ys

Solution

  • Okay, let’s take a look. You said you understand most of it already, but I’ll go through the rest of it anyway, for completeness.

    isort :: [Int] -> [Int]
    

    isort (insertion sort) takes and returns a list of Int.

    isort [] = []
    

    A base case of the recursive function: the empty list is sorted. In practice, sorting an empty list is trivial and we rarely care about it, and recursion on a non-empty list will always hit the next base case before this one, so this line is here primarily so the patterns cover every possible list.

    isort [x] = [x]
    

    A singleton list is sorted, tautologically. This is the base case that we will reach in practice.
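    As a sketch of why the singleton equation is a convenience rather than a necessity: if you drop it, the function still works, because isort [x] then unfolds to inserting x into isort [], which is [x]. The name isortNoSingleton is hypothetical; everything else is the question's code.

```haskell
-- Hypothetical variant of the question's isort with the singleton
-- equation removed; only the empty list is a base case.
isortNoSingleton :: [Int] -> [Int]
isortNoSingleton []     = []
isortNoSingleton (x:xs) = insert (isortNoSingleton xs)
  where
    -- Same local helper as in the question: x is captured from the
    -- enclosing equation, and the argument is assumed to be sorted.
    insert [] = [x]
    insert (y:ys)
      | x < y     = x : y : ys
      | otherwise = y : insert ys

-- In GHCi:
--   isortNoSingleton [3,2,1]  evaluates to [1,2,3]
--   isortNoSingleton [5]      evaluates to [5], via insert into []
```

    Keeping the singleton case just saves one trivial insert per call chain; either way the empty-list equation is what makes the patterns exhaustive.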

    isort (x:xs) = insert (isort xs)
    

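    Now we get to the line you asked about. This is the recursive case. It first sorts the tail of the list recursively with isort xs, then calls the insert function, defined below, on that sorted tail. You can think of the head of the list, x, as a hidden parameter of insert: because insert is defined in a where clause attached to this equation, it can see x.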

    This function basically sorts the tail of the list and then inserts the head into that sorted list:

      where
            insert [] = [x]
    

    Inserting x into the empty list produces a singleton list.

            insert (y:ys)
              | x < y = x : y : ys
              | otherwise = y : insert ys
    

    This pair of cases handles the non-trivial insertion with a second recursive function. It assumes that its argument (the already-sorted tail, isort xs) is sorted, and maintains that invariant. As such, it traverses the input list and, the first time it encounters an element greater than x, it splices x in at the head of the remaining list. (One easy optimization would be to also stop when y is equal to x, instead of walking past a run of duplicates; imagine sorting a list that contains the same element millions of times!)

    While the elements of the input list are less than (or, as written, equal to) x, it creates a new prefix list which will be spliced in front of x.
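    To make the hidden parameter visible, here is a sketch that lifts insert to the top level with x as an ordinary argument. The names insertInto and isortExplicit are hypothetical, and it uses <= instead of <, which is the duplicates optimization mentioned above: it stops at the first element equal to x instead of walking past all of them.

```haskell
-- Hypothetical names: insertInto and isortExplicit. This lifts the
-- question's local insert to the top level, so the element being
-- inserted (the "hidden parameter" x) becomes an explicit argument.
insertInto :: Int -> [Int] -> [Int]
insertInto x [] = [x]                 -- inserting into [] gives a singleton
insertInto x (y:ys)
  | x <= y    = x : y : ys            -- first y not smaller than x: splice x in here
  | otherwise = y : insertInto x ys   -- keep y in the prefix and keep looking

isortExplicit :: [Int] -> [Int]
isortExplicit []     = []
isortExplicit (x:xs) = insertInto x (isortExplicit xs)  -- sort the tail, insert the head
```

    For example, insertInto 3 [1,2,4] evaluates to [1,2,3,4], and insertInto 2 [1,2,2,5] stops at the first 2, giving [1,2,2,2,5].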

    So, if you call it with x:xs as [3,2,1], it will first split that into 3:[2,1] and call itself recursively on [2,1]. That call splits [2,1] into 2:[1] and recurses on [1], which is a singleton, so it stops recursing and starts unwinding the stack. On the way back out, it inserts 2 into [1]: it sees that 2 is not less than 1, keeps 1 in the prefix, recurses to insert 2 into [], and gets 1:[2]. Then it gets back to the scope of the original call and inserts 3 into [1,2]. It now recurses twice, keeping 1 and 2 as the prefix, and inserts 3 into [] to get 1:2:[3], or [1,2,3], which it returns.
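    You can watch the same unwinding directly. The recursion in isort has exactly the shape of a right fold, so the Prelude's scanr lists every intermediate sorted suffix. This is a sketch; insertInto, isortR and traceSuffixes are hypothetical names, with insertInto being the question's local insert moved to the top level and x passed explicitly.

```haskell
-- Hypothetical top-level version of the question's local insert.
insertInto :: Int -> [Int] -> [Int]
insertInto x [] = [x]
insertInto x (y:ys)
  | x < y     = x : y : ys
  | otherwise = y : insertInto x ys

-- isort (x:xs) = insertInto x (isort xs) is exactly the pattern foldr captures.
isortR :: [Int] -> [Int]
isortR = foldr insertInto []

-- scanr keeps every intermediate result of that right fold, i.e. the
-- sorted version of each suffix of the input, longest suffix first.
traceSuffixes :: [Int] -> [[Int]]
traceSuffixes = scanr insertInto []
```

    In GHCi, traceSuffixes [3,2,1] gives [[1,2,3],[1,2],[1],[]]. Reading it right to left: [1] is sorted, inserting 2 gives [1,2], and inserting 3 gives [1,2,3], matching the walkthrough.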

    That implementation is a good start, but an experienced Haskell programmer would probably reach for a fold instead. You're more likely to see something like this expressed as a strict left fold, where the accumulating parameter is the sorted prefix of the list so far and each step inserts the next element of the input into it to obtain a new sorted prefix. (Well, in the real world they would use a different sorting algorithm that doesn't take quadratic time, and so would you if you weren't doing this as an exercise, but the strict-left-fold and tail-recursion-modulo-cons patterns are very useful.) The fold itself is tail-recursive, so foldl' can run like a while loop in an imperative language, forcing the accumulator at each step, instead of first nesting one pending call per element the way isort (x:xs) = insert (isort xs) does. A right fold, where you start with the empty list and insert each element in reverse order, would also work, although it doesn't really have any advantages here: the question's isort already is that right fold, with the head x passed to insert implicitly.

    Some sample code, which as you can see is very similar:

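    import Data.List (foldl')
    
    -- Fold from the left: the accumulator is the sorted list built so far.
    insertSort :: Ord a => [a] -> [a]
    insertSort = foldl' insert [] where
      -- Insert x into an already-sorted list (accumulator first, element second).
      insert [] x = [x]
      insert (y:ys) x | x <= y    = x:y:ys
                      | otherwise = y:(insert ys x)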
    

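    The difference if you walk through the [3,2,1] example again is that it goes from left to right, first generating [3], then inserting 2 to get [2,3], then inserting 1 to get [1,2,3]. The fold makes its recursive calls as tail calls and forces the accumulator at each step, so the fold itself runs in constant stack space instead of nesting a call per element. Bear in mind that foldl' only forces the accumulator to weak head normal form (its first cons cell), so parts of each insertion can still sit in unevaluated thunks, and both versions still do O(n²) comparisons in the worst case; it may need less memory at once, but measure before expecting it to be much faster. The insert helper function itself is tail-recursive modulo cons: its recursive call sits directly under the (:) constructor.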
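    To see that left-to-right order for yourself, the Prelude's scanl (the counterpart of foldl that keeps every intermediate accumulator) can replay the fold. A sketch: insertAcc and tracePrefixes are hypothetical names, and insertAcc is the answer's local insert moved to the top level.

```haskell
import Data.List (foldl')

-- The answer's local insert, lifted to the top level (hypothetical name):
-- accumulator (already-sorted list) first, element to insert second.
insertAcc :: Ord a => [a] -> a -> [a]
insertAcc [] x = [x]
insertAcc (y:ys) x
  | x <= y    = x : y : ys
  | otherwise = y : insertAcc ys x

-- Same strict left fold as insertSort above.
insertSort :: Ord a => [a] -> [a]
insertSort = foldl' insertAcc []

-- scanl shows each accumulator value the left fold passes along,
-- starting with the empty list.
tracePrefixes :: Ord a => [a] -> [[a]]
tracePrefixes = scanl insertAcc []
```

    In GHCi, tracePrefixes [3,2,1] gives [[],[3],[2,3],[1,2,3]], the same sequence as the walkthrough, and the last element is what insertSort [3,2,1] returns.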