Tags: algorithm, scala, linked-list, tail-recursion

Tail-recursive solution in Scala for linked-list chaining


I wanted to write a tail-recursive solution for the following problem on LeetCode:

You are given two non-empty linked lists representing two non-negative integers. The digits are stored in reverse order and each of their nodes contains a single digit. Add the two numbers and return it as a linked list.

You may assume the two numbers do not contain any leading zero, except the number 0 itself.

Example:

*Input: (2 -> 4 -> 3) + (5 -> 6 -> 4)*
*Output: 7 -> 0 -> 8*
*Explanation: 342 + 465 = 807.*
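To make the reverse-order representation concrete, here is a minimal sketch of two hypothetical helpers (DigitsExample.toNumber and DigitsExample.fromNumber, not part of the LeetCode problem) that convert between a reversed digit list and a BigInt. They only reproduce the example's arithmetic and can act as a reference when testing a real solution; they are not a solution themselves, since the exercise is about adding digit by digit.

```scala
object DigitsExample {
  // Digits are stored least significant first, so List(2, 4, 3) represents 342.
  // foldRight combines the last element (the most significant digit) first.
  def toNumber(digits: List[Int]): BigInt =
    digits.foldRight(BigInt(0))((d, acc) => acc * 10 + d)

  // Inverse direction: split a non-negative BigInt into digits,
  // least significant first. Plain recursion is fine for short numbers.
  def fromNumber(n: BigInt): List[Int] =
    if (n < 10) List(n.toInt)
    else (n % 10).toInt :: fromNumber(n / 10)
}
```

For instance, fromNumber(toNumber(List(2, 4, 3)) + toNumber(List(5, 6, 4))) gives List(7, 0, 8), matching 342 + 465 = 807.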

Link to the problem on LeetCode

I was not able to figure out how to make the recursive call on the last line a tail call. What I am trying to achieve is to call add recursively: each call adds the heads of the two lists plus a carry and returns a node, and that returned node is chained to the node created in the calling stack frame.

I am pretty new to Scala, so I am guessing I may have missed some useful constructs.

/**
 * Definition for singly-linked list.
 * class ListNode(_x: Int = 0, _next: ListNode = null) {
 *   var next: ListNode = _next
 *   var x: Int = _x
 * }
 */
import scala.annotation.tailrec
object Solution {
  def addTwoNumbers(l1: ListNode, l2: ListNode): ListNode = {
    add(l1, l2, 0)
  }
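  // @tailrec cannot be added here: the recursive call to add is an argument
  // to the ListNode constructor, so it is not in tail position.
  def add(l1: ListNode, l2: ListNode, carry: Int): ListNode = {
    var sum = 0
    sum = (if (l1 != null) l1.x else 0) + (if (l2 != null) l2.x else 0) + carry
    if (l1 != null || l2 != null || sum > 0)
      // ListNode is a plain class (not a case class), so Scala 2 requires `new`.
      new ListNode(sum % 10, add(if (l1 != null) l1.next else null, if (l2 != null) l2.next else null, sum / 10))
    else null
  }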
}

Solution

  • Your code has a couple of problems, most of which come down to it not being idiomatic Scala.

    Things like var and null are uncommon in Scala; usually, you would use a tail-recursive algorithm to avoid that kind of thing.

    Finally, remember that a tail-recursive function requires that the last expression of every branch is either a plain value or a recursive call whose result is returned unchanged. To achieve that, you usually pass along the remaining work as well as an accumulator holding the partial result.
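    As a generic illustration of that rule (a hypothetical increment function, unrelated to the problem itself), compare a version shaped like the question's add, where the result is assembled after the recursive call returns, with an accumulator version that @tailrec accepts:

    ```scala
    import scala.annotation.tailrec

    object AccumulatorExample {
      // Not tail-recursive: the cons (::) runs after the recursive call returns,
      // just like ListNode(sum % 10, add(...)) in the question.
      def incrementNaive(xs: List[Int]): List[Int] = xs match {
        case Nil          => Nil
        case head :: tail => (head + 1) :: incrementNaive(tail)
      }

      // Tail-recursive: finished elements are prepended to an accumulator, so the
      // recursive call is the last expression. The accumulator ends up reversed.
      def incrementAcc(xs: List[Int]): List[Int] = {
        @tailrec
        def loop(remaining: List[Int], acc: List[Int]): List[Int] = remaining match {
          case Nil          => acc.reverse
          case head :: tail => loop(tail, (head + 1) :: acc)
        }
        loop(xs, Nil)
      }
    }
    ```

    Because the accumulator collects results back to front, it has to be reversed once at the end; the answer's loop uses the same trick.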

    Here is a possible solution:

    type Digit = Int // Refined [0..9]
    type Number = List[Digit] // Refined NonEmpty; least significant digit first.
    
    def sum(n1: Number, n2: Number): Number = {
      // Adds two digits plus a carry; returns (resulting digit, new carry).
      def aux(d1: Digit, d2: Digit, carry: Digit): (Digit, Digit) = {
        val tmp = d1 + d2 + carry
        val d = tmp % 10
        val c = tmp / 10
        
        d -> c
      }
    
      // r1 and r2 are the digits still to process; acc holds the result digits
      // produced so far, most significant first (hence the final reverse).
      @annotation.tailrec
      def loop(r1: Number, r2: Number, acc: Number, carry: Digit): Number =
        (r1, r2) match {
          case (d1 :: tail1, d2 :: tail2) =>
            val (d, c) = aux(d1, d2, carry)
            loop(r1 = tail1, r2 = tail2, acc = d :: acc, carry = c)
    
          case (Nil, d2 :: tail2) =>
            val (d, c) = aux(0, d2, carry)
            loop(r1 = Nil, r2 = tail2, acc = d :: acc, carry = c)
    
          case (d1 :: tail1, Nil) =>
            val (d, c) = aux(d1, 0, carry)
            loop(r1 = tail1, r2 = Nil, acc = d :: acc, carry = c)
    
          case (Nil, Nil) =>
            // Keep a leftover carry: 5 + 5 = 10 needs an extra digit.
            if (carry > 0) carry :: acc else acc
        }
    
      loop(r1 = n1, r2 = n2, acc = List.empty, carry = 0).reverse
    }
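    The sum above works on List[Int] rather than LeetCode's ListNode. To plug it into the addTwoNumbers(l1: ListNode, l2: ListNode): ListNode signature, one option is a pair of conversion helpers. The sketch below assumes the ListNode class from the commented template in the question; the names ListNodeAdapters, toList and fromList are hypothetical.

    ```scala
    import scala.annotation.tailrec

    // Same shape as the ListNode definition in the LeetCode template.
    class ListNode(_x: Int = 0, _next: ListNode = null) {
      var next: ListNode = _next
      var x: Int = _x
    }

    object ListNodeAdapters {
      // Walks the nodes front to back; the recursive call is in tail position.
      @tailrec
      def toList(node: ListNode, acc: List[Int] = Nil): List[Int] =
        if (node == null) acc.reverse
        else toList(node.next, node.x :: acc)

      // Builds the nodes back to front, so each new node can point at the
      // part of the list that has already been built.
      def fromList(digits: List[Int]): ListNode =
        digits.reverse.foldLeft(null: ListNode)((next, d) => new ListNode(d, next))
    }
    ```

    With these, addTwoNumbers could be written as fromList(sum(toList(l1), toList(l2))), assuming sum is the list-based function above. The conversions cost two extra passes over the data, which is usually negligible next to the readability gained.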
    

    Now, this kind of recursion tends to be very verbose.
    Usually, the standard library provides ways to express the same algorithm more concisely:

    // This version does not require the numbers to be stored in reverse order:
    // both inputs and the output are most significant digit first.
    def sum(n1: Number, n2: Number): Number = {
      // Walk both numbers from the least significant digit, padding the shorter
      // one with zeros, and prepend each result digit to acc.
      val (result, carry) =
        n1.reverseIterator.zipAll(n2.reverseIterator, 0, 0).foldLeft(List.empty[Digit] -> 0) {
          case ((acc, carry), (d1, d2)) =>
            val tmp = d1 + d2 + carry
            val d = tmp % 10
            val c = tmp / 10
            (d :: acc) -> c
        }
    
      // A leftover carry becomes the new most significant digit.
      if (carry > 0) carry :: result else result
    }
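    If you need to keep LeetCode's mutable ListNode signature instead of switching to List, a different technique (not the one used in this answer) makes the question's original recursion tail-recursive: rather than chaining the returned node in the caller, each call appends its node to the tail of the list built so far, starting from a sentinel node. This is a sketch assuming the ListNode class from the question's template; the names TailRecSolution, loop, dummy and last are hypothetical. It still relies on null and mutation, which the answer points out are not idiomatic, so treat it as a pragmatic fallback.

    ```scala
    import scala.annotation.tailrec

    // Same shape as the ListNode definition in the LeetCode template.
    class ListNode(_x: Int = 0, _next: ListNode = null) {
      var next: ListNode = _next
      var x: Int = _x
    }

    object TailRecSolution {
      def addTwoNumbers(l1: ListNode, l2: ListNode): ListNode = {
        // Sentinel node: its `next` ends up pointing at the first real digit.
        val dummy = new ListNode()

        // `last` is the most recently appended node. Each step links a new node
        // onto it and passes that node along, so no work remains after the call.
        @tailrec
        def loop(a: ListNode, b: ListNode, carry: Int, last: ListNode): Unit =
          if (a != null || b != null || carry > 0) {
            val sum = (if (a != null) a.x else 0) + (if (b != null) b.x else 0) + carry
            val node = new ListNode(sum % 10)
            last.next = node
            loop(if (a != null) a.next else null, if (b != null) b.next else null, sum / 10, node)
          }

        loop(l1, l2, 0, dummy)
        dummy.next
      }
    }
    ```

    The sentinel avoids a special case for the first node: every digit, including the first, is attached the same way, and the real result is simply dummy.next.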