Tags: python, algorithm, sorting, mergesort

Sort a linked list in O(n log n) time using merge sort


This is actually an algorithm practice problem from LeetCode, and below is my code:

class ListNode:
    def __init__(self, x, next=None):
        self.val = x
        self.next = next

class Solution:
    # @param head, a ListNode
    # @return a ListNode
    def sortList(self, head):
        if head is None or head.next is None:
            return head
        first, second = self.divide_list(head)
        self.sortList(first)
        self.sortList(second)
        return self.merge_sort(first, second)

    def merge_sort(self, first, second):
        if first is None:
            return second
        if second is None:
            return first
        left, right = first, second
        if left.val <= right.val:
            current, head = left, first
            left = left.next
        else:
            current, head = right, second
            right = right.next
        while left and right:
            if left.val <= right.val:
                current.next = left
                left = left.next
            else:
                current.next = right
                right = right.next
            current = current.next

        if left is not None:
            current.next = left
        if right is not None:
            current.next = right
        return head

    def divide_list(self, head):
        fast, slow = head.next, head
        while fast.next:
            slow = slow.next
            fast = fast.next
            if fast.next:
                fast = fast.next
        second_part = slow.next
        slow.next = None
        return head, second_part

The idea is quite straightforward, just the basic concept of merge sort, but the result seems incorrect, and the running time is too long to pass LeetCode's judge (Time Limit Exceeded; but why, if this is O(n log n)?). Below is my test code:

Basic test:

c = ListNode(3, ListNode(1, ListNode(2, ListNode(4))))
result = Solution().sortList(c)
while result:
    print(result.val)
    result = result.next  # prints 2, 3, 4; the 1 is missing
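Printing node by node makes it easy to overlook a dropped value. One option, sketched here with hypothetical helper names from_list and to_list (they are not part of the original code or of LeetCode's harness), is to convert between Python lists and ListNode chains so the output can be compared directly with sorted(values):

```python
class ListNode:
    def __init__(self, x, next=None):
        self.val = x
        self.next = next


def from_list(values):
    # Hypothetical helper: build the nodes back to front so each new node
    # points at the head built so far.
    head = None
    for v in reversed(values):
        head = ListNode(v, head)
    return head


def to_list(head):
    # Hypothetical helper: walk the chain and collect each node's value.
    out = []
    while head:
        out.append(head.val)
        head = head.next
    return out


print(to_list(from_list([3, 1, 2, 4])))  # [3, 1, 2, 4]
```

With these, the basic test becomes a single comparison: to_list(Solution().sortList(from_list([3, 1, 2, 4]))) should equal [1, 2, 3, 4] once the sort is correct.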

Does anyone have an idea how to optimize this code?
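On the complexity question: splitting with the slow/fast pointers and merging are both linear in the length of the sublist being processed, and each call recurses on two halves, so the running time follows the standard merge sort recurrence (c is a constant; the closed form is exact when n is a power of two):

$$
T(n) = 2\,T(n/2) + c\,n,\qquad T(1) = c \;\Longrightarrow\; T(n) = c\,n\log_2 n + c\,n = O(n \log n).
$$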


Solution

  • The offending lines are

            self.sortList(first)
            self.sortList(second)
    

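    The problem is that sortList returns the head of the sorted list, and that head may be a different node from the one you passed in. Discarding the return value leaves first and second pointing at the old head nodes, which can now sit in the middle of their sorted sublists, so any nodes in front of them are lost. The fix is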

            first = self.sortList(first)
            second = self.sortList(second)
    

    As a general hint, I would suggest that you employ sentinel nodes to cut down on the number of special cases.
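Putting the fix and the sentinel hint together, a minimal sketch might look like the following. It is an illustration, not a tuned LeetCode submission: merge is a renamed, sentinel-based stand-in for the question's merge_sort (which only merges), and divide_list uses the common fast.next.next form of the slow/fast split, which cuts the list at the same node as the original loop.

```python
class ListNode:
    def __init__(self, x, next=None):
        self.val = x
        self.next = next


class Solution:
    def sortList(self, head):
        # A list of zero or one node is already sorted.
        if head is None or head.next is None:
            return head
        first, second = self.divide_list(head)
        # Keep the returned heads: sorting may move a different node to the front.
        first = self.sortList(first)
        second = self.sortList(second)
        return self.merge(first, second)

    def merge(self, first, second):
        # The sentinel gives the merged list a fixed starting point,
        # so choosing the first real node needs no special case.
        sentinel = ListNode(0)
        tail = sentinel
        while first and second:
            if first.val <= second.val:
                tail.next, first = first, first.next
            else:
                tail.next, second = second, second.next
            tail = tail.next
        # At most one list still has nodes; append it unchanged.
        tail.next = first if first else second
        return sentinel.next

    def divide_list(self, head):
        # Assumes at least two nodes. fast moves two steps per iteration,
        # so slow stops at the last node of the first half.
        slow, fast = head, head.next
        while fast and fast.next:
            slow = slow.next
            fast = fast.next.next
        second = slow.next
        slow.next = None  # cut the list in two
        return head, second


# Usage: sort the question's example list 3 -> 1 -> 2 -> 4.
node = Solution().sortList(ListNode(3, ListNode(1, ListNode(2, ListNode(4)))))
values = []
while node:
    values.append(node.val)
    node = node.next
print(values)  # [1, 2, 3, 4]
```

The sentinel removes the separate if/else the original merge_sort needed before its loop to pick the first node; the merged list simply starts at sentinel.next.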