
Linked List merge sort

Given a singly linked list, how can we sort it in O(n log n) time? We are familiar with merge sort for arrays; how can we adapt it to sort a list?

  • Recursive approach

The basic idea is to scan the list to find its middle, break the list into two halves, sort the two sub-lists recursively, and merge them. The time complexity is O(n log n): splitting and merging a list of n nodes takes O(n) time, and because the list is halved at each level of recursion there are O(log n) levels. What about space? The merge only relinks existing nodes, but the function is recursive and uses stack space. Since the list is halved at each level, at most O(log n) calls are on the stack at any time, so the space complexity is O(log n).

The code is shown as follows:

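public ListNode sortList(ListNode head) {
    // An empty or single-node list is already sorted.
    if(head==null || head.next==null)return head;
    // Find the middle: slow stops at the last node of the first half.
    ListNode slow=head;
    ListNode fast=head.next;
    while(fast!=null && fast.next!=null){
        slow=slow.next;
        fast=fast.next.next;
    }
    // Break the list into two halves.
    ListNode head2=slow.next;
    slow.next=null;
    // Sort both halves; their heads may change, so keep the returned heads.
    head=sortList(head);
    head2=sortList(head2);
    return merge(head,head2);
}

private ListNode merge(ListNode h1,ListNode h2){
    ListNode dh=new ListNode(-1);   // dummy head: avoids special-casing the first node
    ListNode p=dh;                  // tail of the merged list
    while(h1!=null || h2!=null){
        if(h2==null){
            p.next=h1;
            h1=h1.next;
        } else if(h1==null){
            p.next=h2;
            h2=h2.next;
        } else {
            // Taking from h1 on ties ("<=") keeps the sort stable.
            if(h1.val<=h2.val){
                p.next=h1;
                h1=h1.next;
            } else {
                p.next=h2;
                h2=h2.next;
            }
        }
        p=p.next;
    }
    return dh.next;
}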
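The middle-finding loop in sortList starts fast at head.next rather than at head. The small sketch below shows why that matters; the class name, the split helper and its boolean flag are hypothetical, and the nested ListNode is an assumed minimal node type. With fast starting at head.next, a two-node list splits into 1 + 1 nodes; starting at head leaves both nodes in the first half, so the recursion on that half would never shrink.

```java
import java.util.ArrayList;
import java.util.List;

class SplitDemo {
    // Assumed minimal node type (val, next), as used by the article's code.
    static class ListNode {
        int val;
        ListNode next;
        ListNode(int val) { this.val = val; }
    }

    static ListNode fromArray(int... vals) {
        ListNode dh = new ListNode(-1); // dummy head
        ListNode p = dh;
        for (int v : vals) { p.next = new ListNode(v); p = p.next; }
        return dh.next;
    }

    static List<Integer> toList(ListNode head) {
        List<Integer> out = new ArrayList<>();
        for (ListNode p = head; p != null; p = p.next) out.add(p.val);
        return out;
    }

    // Cuts the list after the node where slow stops and returns the second half.
    // fastStartsAtNext = true reproduces sortList; false starts fast at head.
    static ListNode split(ListNode head, boolean fastStartsAtNext) {
        ListNode slow = head;
        ListNode fast = fastStartsAtNext ? head.next : head;
        while (fast != null && fast.next != null) {
            slow = slow.next;
            fast = fast.next.next;
        }
        ListNode head2 = slow.next;
        slow.next = null;
        return head2;
    }

    public static void main(String[] args) {
        ListNode a = fromArray(1, 2, 3, 4, 5);
        ListNode b = split(a, true);
        System.out.println(toList(a) + " " + toList(b)); // [1, 2, 3] [4, 5]

        ListNode c = fromArray(1, 2);
        ListNode d = split(c, true);
        System.out.println(toList(c) + " " + toList(d)); // [1] [2]

        ListNode e = fromArray(1, 2);
        ListNode f = split(e, false);
        System.out.println(toList(e) + " " + toList(f)); // [1, 2] [] (no progress)
    }
}
```

For odd lengths the extra node goes to the first half, which is harmless; what matters is that both halves are strictly shorter than the input whenever it has at least two nodes.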

Tip: when you sort a sub-list recursively, do not forget to update the head of that sub-list! After sorting, the head may be a different node.
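To see why the tip matters, the sketch below sorts 3 -> 1 -> 2 and then walks the list from both the returned head and the old head reference. It is a condensed copy of the recursive version for illustration (its merge stops when either list runs out and appends the rest, which is equivalent); the class and helper names are hypothetical and ListNode is an assumed minimal node type.

```java
import java.util.ArrayList;
import java.util.List;

class HeadChangeDemo {
    // Assumed minimal node type (val, next).
    static class ListNode {
        int val;
        ListNode next;
        ListNode(int val) { this.val = val; }
    }

    // Condensed recursive merge sort with the same split as the article's version.
    static ListNode sortList(ListNode head) {
        if (head == null || head.next == null) return head;
        ListNode slow = head, fast = head.next;
        while (fast != null && fast.next != null) {
            slow = slow.next;
            fast = fast.next.next;
        }
        ListNode head2 = slow.next;
        slow.next = null;
        // Use the heads returned by the recursive calls, not the old ones.
        return merge(sortList(head), sortList(head2));
    }

    static ListNode merge(ListNode h1, ListNode h2) {
        ListNode dh = new ListNode(-1), p = dh;
        while (h1 != null && h2 != null) {
            if (h1.val <= h2.val) { p.next = h1; h1 = h1.next; }
            else { p.next = h2; h2 = h2.next; }
            p = p.next;
        }
        p.next = (h1 != null) ? h1 : h2; // append whichever list is left
        return dh.next;
    }

    static ListNode fromArray(int... vals) {
        ListNode dh = new ListNode(-1);
        ListNode p = dh;
        for (int v : vals) { p.next = new ListNode(v); p = p.next; }
        return dh.next;
    }

    static List<Integer> toList(ListNode head) {
        List<Integer> out = new ArrayList<>();
        for (ListNode p = head; p != null; p = p.next) out.add(p.val);
        return out;
    }

    public static void main(String[] args) {
        ListNode head = fromArray(3, 1, 2);
        ListNode sorted = sortList(head);
        System.out.println(toList(sorted)); // [1, 2, 3]
        // head still refers to the node holding 3, which is now the tail.
        System.out.println(toList(head));   // [3]
    }
}
```

Reading from the stale reference silently loses every node that sorted in front of it, which is exactly the bug the tip warns about.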

  • Iterative approach

Scan the list repeatedly, merging adjacent sorted runs: the first pass sorts every sub-list of length 2, the second pass every sub-list of length 4, ..., and the nth pass every sub-list of length 2^n, stopping once 2^n is equal to or greater than the length of the whole list.
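The pass structure can be traced with a compact sketch of the same bottom-up idea. This is a common alternative formulation, not the article's code: the hypothetical helpers split (cut off the first n nodes) and mergeAfter (merge two runs after a given tail and return the new tail) replace the inline pointer walking, and ListNode is an assumed minimal node type. On 5 3 4 1 2 it needs three passes, since 2^3 = 8 is the first power of two not smaller than 5.

```java
import java.util.ArrayList;
import java.util.List;

class BottomUpTrace {
    // Assumed minimal node type (val, next), as used by the article's code.
    static class ListNode {
        int val;
        ListNode next;
        ListNode(int val) { this.val = val; }
    }

    // Cuts the list after its first n nodes and returns the head of the remainder (or null).
    static ListNode split(ListNode head, int n) {
        for (int i = 1; head != null && i < n; i++) head = head.next;
        if (head == null) return null;
        ListNode rest = head.next;
        head.next = null;
        return rest;
    }

    // Merges sorted runs a and b after tail; returns the last node of the merged run.
    static ListNode mergeAfter(ListNode tail, ListNode a, ListNode b) {
        while (a != null && b != null) {
            if (a.val <= b.val) { tail.next = a; a = a.next; }
            else { tail.next = b; b = b.next; }
            tail = tail.next;
        }
        tail.next = (a != null) ? a : b; // append the leftover run
        while (tail.next != null) tail = tail.next;
        return tail;
    }

    // Bottom-up merge sort that records the list contents after every pass.
    static List<List<Integer>> sortWithTrace(ListNode head) {
        List<List<Integer>> passes = new ArrayList<>();
        int len = 0;
        for (ListNode p = head; p != null; p = p.next) len++;
        ListNode dh = new ListNode(-1); // dummy head
        dh.next = head;
        for (int size = 1; size < len; size *= 2) {
            ListNode tail = dh;
            ListNode cur = dh.next;
            while (cur != null) {
                ListNode left = cur;
                ListNode right = split(left, size); // left now holds at most size nodes
                cur = split(right, size);           // right now holds at most size nodes
                tail = mergeAfter(tail, left, right);
            }
            passes.add(toList(dh.next));
        }
        return passes;
    }

    static ListNode fromArray(int... vals) {
        ListNode dh = new ListNode(-1);
        ListNode p = dh;
        for (int v : vals) { p.next = new ListNode(v); p = p.next; }
        return dh.next;
    }

    static List<Integer> toList(ListNode head) {
        List<Integer> out = new ArrayList<>();
        for (ListNode p = head; p != null; p = p.next) out.add(p.val);
        return out;
    }

    public static void main(String[] args) {
        // Prints [3, 5, 1, 4, 2], then [1, 3, 4, 5, 2], then [1, 2, 3, 4, 5].
        for (List<Integer> pass : sortWithTrace(fromArray(5, 3, 4, 1, 2))) {
            System.out.println(pass);
        }
    }
}
```

A trailing run without a partner (the lone 2 in the first two passes) is simply carried over unchanged until a later pass merges it.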

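The time complexity is O(n log n): each pass walks and relinks every node a constant number of times, and there are ⌈log2 n⌉ passes. The space complexity is O(1), because there is no recursion and only a fixed number of pointers is used.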
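The pass count follows the outer loop of the iterative sortList, where subsize takes the values 1, 2, 4, ... while it is smaller than the list length. A small sketch (class and method names are hypothetical) that counts the passes for a few lengths:

```java
class PassCount {
    // Mirrors the outer loop of the iterative sortList:
    // subsize = 1, 2, 4, ... while subsize < len, one pass per value.
    // subsize is a long so doubling cannot overflow for large int lengths.
    static int passes(int len) {
        int count = 0;
        for (long subsize = 1; subsize < len; subsize *= 2) count++;
        return count;
    }

    public static void main(String[] args) {
        for (int n : new int[] {1, 2, 5, 8, 1000, 1_000_000}) {
            int p = passes(n);
            // Each pass does O(n) work, so total work is proportional to n * passes.
            System.out.println("n=" + n + "  passes=" + p + "  n*passes=" + (long) n * p);
        }
    }
}
```

For n = 1000 this gives 10 passes, because 2^9 = 512 < 1000 <= 2^10. Multiplying by n gives the O(n log n) bound, while the memory used stays constant no matter how many passes run.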

See the code below:

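public ListNode sortList(ListNode head) {
    int len=getLength(head);
    if(len<2)return head;
    int subsize=1;                  // length of the sorted runs merged in this pass
    ListNode dh=new ListNode(-1);   // dummy head, so the real head can change freely
    dh.next=head;
    while(subsize<len){
        ListNode pp=dh;             // tail of the already-merged part of this pass
        ListNode p=dh.next;         // start of the next left run
        ListNode np=null;           // start of the part after the right run
        while(p!=null){
            // Advance p to the last node of the left run (subsize nodes).
            int x=0;
            while(p!=null && x<subsize){
                x++;
                if(x==subsize)break;
                p=p.next;
            }
            // No right run: the remaining nodes are already linked after pp.
            if(p==null || p.next==null)break;
            // Advance q to the last node of the right run (at most subsize nodes).
            ListNode q=p.next;
            x=0;
            while(q!=null && x<subsize){
                x++;
                if(x==subsize)break;
                q=q.next;
            }
            // Remember where the next pair starts and cut the right run off.
            if(q==null){
                np=null;
            } else {
                np=q.next;
                q.next=null;
            }
            // Detach the two runs: p is the left run, q is the right run.
            q=p.next;
            p.next=null;
            p=pp.next;
            pp.next=null;
            // Merge the two runs, appending nodes after pp.
            while(p!=null || q!=null){
                if(p==null){
                    pp.next=q;q=q.next;
                } else if(q==null){
                    pp.next=p;p=p.next;
                } else {
                    if(p.val<=q.val){
                        pp.next=p;p=p.next;
                    } else {
                        pp.next=q;q=q.next;
                    }
                }
                pp=pp.next;
            }
            // Reattach the rest of the list and move on to the next pair.
            p=np;
            pp.next=p;
        }
        subsize*=2;
    }
    return dh.next;
}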

private int getLength(ListNode head){
    ListNode p=head;
    int len=0;
    while(p!=null){
        len++;
        p=p.next;
    }
    return len;
}
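Neither snippet defines ListNode. Judging from how it is used (new ListNode(-1), .val, .next), a minimal definition could look like the following; this class and the hypothetical ListUtil helpers are assumptions for trying the code out, not part of the original post.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helpers for building and inspecting lists.
class ListUtil {
    static ListNode fromArray(int... vals) {
        ListNode dh = new ListNode(-1); // dummy head, same trick as in sortList
        ListNode p = dh;
        for (int v : vals) {
            p.next = new ListNode(v);
            p = p.next;
        }
        return dh.next;
    }

    static List<Integer> toList(ListNode head) {
        List<Integer> out = new ArrayList<>();
        for (ListNode p = head; p != null; p = p.next) out.add(p.val);
        return out;
    }

    public static void main(String[] args) {
        System.out.println(toList(fromArray(4, 2, 1, 3))); // [4, 2, 1, 3]
    }
}

// Assumed node class: the snippets above use new ListNode(int), .val and .next.
class ListNode {
    int val;
    ListNode next;
    ListNode(int val) { this.val = val; }
}
```

With these in the same package, either sortList method (together with its private helpers) can be pasted into a class, for example a hypothetical Solution, and called as new Solution().sortList(ListUtil.fromArray(4, 2, 1, 3)).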