1. 程式人生 > 實用技巧 >LeetCode - 回溯與剪枝

LeetCode - 回溯與剪枝

回溯演算法的定義:

在包含問題的所有解的解空間樹中,按照深度優先的策略,從根結點出發搜尋解空間樹。演算法搜尋至解空間樹的任一結點時,總是先判斷該結點是否肯定不包含問題的解。如果肯定不包含,則跳過對以該結點為根的子樹的系統搜尋,逐層向其祖先結點回溯。否則,進入該子樹,繼續按深度優先的策略進行搜尋。回溯法在用來求問題的所有解時,要回溯到根,且根結點的所有子樹都已被搜尋遍才結束。而回溯法在用來求問題的任一解時,只要搜尋到問題的一個解就可以結束。

解空間樹:依據待解決問題的特性,用樹結構表示問題的解結構、用葉子表示問題的解的一顆樹。

最初接觸回溯演算法,應該是走迷宮問題用到DFS。對於一些直觀的圖論問題,回溯是很符合常識的思路,“此路不通,原路返回,另尋他路”,這是非常樸素的回溯思路,以致於當時根本沒有意識到這也算是一種演算法思想。

一個錯誤的認識是,回溯只能用於圖論問題,其背後的根源是缺少將一般的問題抽象出解空間樹的能力,以及將一般問題分解為若干狀態的能力。

重新認識回溯思想是遇到全排列問題,這個問題裡沒有明顯的圖和樹,當時甚至沒有往DFS上想,事後看了題解,苦笑、默嘆,以為妙絕。

回溯與剪枝:回溯的本質是列舉和暴力,這意味著他的效率不會高到哪裡去,常常需要剪枝來優化。

全排列

手寫全排列並非難事,對於序列[1, 2, 3],大部分人手動寫出的全排列大概都是這樣的順序

[123, 132, 213, 231, 312, 321]

其內在的邏輯是逐位固定

首先若固定第一位為1,這樣第二位就只剩兩種選擇(請強行聯想樹形結構),即2或3(兩種選擇,兩個子節點)。

若固定第二位為2,則只能固定第三位為3;-> 123

此時已經得到一個排列結果(123),此時向上回溯,看看是否還有別的解(全排列要求得出所有解),取消固定第三位,發現沒有新的解,繼續向上回溯,取消固定第二位,發現當時是有兩種選擇的(2或3),2已經選過,這次可以選3。

若固定第二位為3,則只能固定第三位為2;-> 132

... 遞迴以上過程 ...

難點在於抽象出解空間樹

上面這棵樹的所有葉節點即為最終解。

這道題的題解區liweiwei1419大佬給出的圖描述了更清晰的回溯過程:

此外要注意的是,全排列要求一個數只能用一次,可以設定一個數組來標記對應位置上的數是否已經在當前排列中出現過。

程式碼

class Solution {
    List<List<Integer>> res;
    //標記是否已經用過
    boolean[] vis;
    public List<List<Integer>> permute(int[] nums) {
        res = new ArrayList<>();
        vis = new boolean[nums.length];
        dfs(nums, new ArrayDeque<>(), nums.length);
        return res;
    }
    public void dfs(int[] nums, Deque<Integer> path, int length) {
        //遞迴出口,排列長度 == 序列長度
        if(0 == length) {
            res.add(new ArrayList<>(path));
            return;
        }
        for(int i = 0; i < nums.length; ++i) {
            if(!vis[i]) {
                path.add(nums[i]);
                --length;
                vis[i] = true;
                dfs(nums, path, length);
                //撤銷之前的操作,即回溯
                vis[i] = false;
                ++length;
                path.removeLast();
            }
        }
    }
}

全排列 II

相對上題,唯一的區別是給出的序列中可能包含重複的數字,這意味著按上題方法得出的結果需要去重。

見到有用set去重的,但更好的方法應該是剪枝。

首先考察會出現重複結果的原因。

以序列[1, 2, 1]為例,按照上題逐位固定的思路,當選擇固定第一位時,有1,2,1三種選擇,這裡就可以看到,第一個1和第三個1是相同的,根據全排列的性質,以這兩個1為第一位得到的結果一定會是重複的。

同理,當第一位固定為2,固定第二位時,剩下1,1兩個值相同的選擇,其結果也必定是重複的。

紅框為需要剪枝剪掉的部分。

所以問題變成,有沒有辦法在固定某一位時直接跳過重複的數?

參考C++ unique函式去重的思路——將有序陣列中重複的元素集中到陣列末端然後一併刪除。

這題並非是要把重複元素刪除,而是要識別出重複元素後跳過,顯然,相對於重複值分佈在不同位置的無序陣列,重複值集中在一起的有序陣列要方便得多。

只要在迴圈中加入一個判斷即可

if(i > 0 && nums[i] == nums[i - 1] && vis[i - 1]) {
    break;
}

(陣列排序當然是直接sort...)

完整程式碼

class Solution {
    List<List<Integer>> res;
    boolean[] vis;
    public List<List<Integer>> permuteUnique(int[] nums) {
        res = new ArrayList<>();
        vis = new boolean[nums.length];
        Arrays.sort(nums);
        dfs(nums, nums.length, new ArrayDeque<Integer>());
        return res;
    }
    public void dfs(int[] nums, int length, Deque<Integer> path) {
        if(path.size() == length) {
            res.add(new ArrayList<>(path));
            return;
        }
        for(int i = 0; i < nums.length; ++i) {
            if(vis[i]) {
                continue;
            }
            if(i > 0 && nums[i] == nums[i - 1] && vis[i - 1]) {
                break;
            }
            path.add(nums[i]);
            vis[i] = true;
            dfs(nums, length, path);
            vis[i] = false;
            path.removeLast();
        }
    }
}

子集

起初是把不同長度的子集分開求解的,其實並沒有這個必要。

原先的程式碼

class Solution {
    List<List<Integer>> res;
    public List<List<Integer>> subsets(int[] nums) {
        res = new ArrayList<>();
        //分別求解0~length長度的子集
        for(int i = 0; i <= nums.length; ++i) {
            dfs(nums, i, 0, new ArrayDeque<>());
        }
        return res;
    }
    public void dfs(int[] nums, int length, int begin, Deque<Integer> path) {
        if(0 == length) {
            res.add(new ArrayList<>(path));
            return;
        }
        for(int i = begin; i < nums.length; ++i) {
            path.add(nums[i]);
            --length;
            dfs(nums, length, i + 1, path);
            ++length;
            path.removeLast();
        }
    }
}

在全排列中,遞迴終止的條件是序列長 == 當前排列的長,而在求子集時,並沒有對子集長度的要求,直接新增進結果集即可。

class Solution {
    List<List<Integer>> res;
    public List<List<Integer>> subsets(int[] nums) {
        res = new ArrayList<>();
        dfs(nums, 0, new ArrayDeque<>());
        return res;
    }
    public void dfs(int[] nums, int begin, Deque<Integer> path) {
        res.add(new ArrayList<>(path));
        for(int i = begin; i < nums.length; ++i) {
            path.add(nums[i]);
            dfs(nums, i + 1, path);
            path.removeLast();
        }
    }
}

子集 II

借鑑全排列||的剪枝去重法,幾乎一樣的程式碼

class Solution {
    List<List<Integer>> res;
    public List<List<Integer>> subsetsWithDup(int[] nums) {
        res = new ArrayList<>();
        if(0 == nums.length) {
            return res;
        }
        Arrays.sort(nums);
        Deque<Integer> path = new ArrayDeque<>();
        dfs(nums, 0, path);
        return res;
    }

    public void dfs(int[] nums, int begin, Deque<Integer> path) {
        res.add(new ArrayList<>(path));
        for(int i = begin; i < nums.length; ++i) {
            if(i > begin && nums[i] == nums[i - 1]) {
                continue;
            }
            path.add(nums[i]);
            dfs(nums, i + 1, path);
            path.removeLast();
        }
    }
}

組合

普通解法很容易得出,和上面的題目相比,只是判斷遞迴終止的條件有所變化而已。

class Solution {
    List<List<Integer>> res;
    public List<List<Integer>> combine(int n, int k) {
        res = new ArrayList<>();
        if(n < k || k < 0) {
            return res;
        }
        Deque<Integer> path = new ArrayDeque<>();
        dfs(n, k, 1, path);
        return res;
    }
    public void dfs(int n, int k, int begin, Deque<Integer> path) {
        if(path.size() == k) {
            res.add(new ArrayList(path));
            return;
        }
        for(int i = begin; i <= n; ++i) {
            path.add(i);
            dfs(n, k, i + 1, path);
            path.removeLast();
        }
    }
}

剪枝

依然考慮一般手寫組合的方法,n = 4(序列為[1, 2, 3, 4]),k = 2時,通常會逐個寫出[1, 2], [1, 3], [1, 4], [2, 3], [2, 4]...其實還是逐個固定,組合與全排列的區別僅僅是是否在乎順序而已。按照這個思路,第一位可以分別固定為1,2,3,這裡不會去考慮將第一位固定為4,因為第二位還需要一個數,而4後面已經沒有可以用來加入組合的數了,此時將4加入path是無意義的。這類情況可以被簡單概括為,當按照逐位固定,逐層回溯的思路求解組合時,當需要的組合的上界超過序列的上界,後面的搜尋將變成冗餘,此時就需要剪枝。

根據上述的推理,即當 搜尋起點 - 1 + 組合長度 > 序列上界 (①)時,搜尋可以停止。(-1是因為起點上也有一個數)

考察搜尋過程中幾個變數的意義,i表示搜尋在序列中的進度,path表示已經固定的部分組合,k表示所需的組合長度,n表示序列上界。

不難得出下式:

搜尋起點 = i - path.size()

代入不等式①可得停止搜尋的條件為:

i - path.size() -1 + k > n

即i > n - k + path.size() + 1

所以,只需將搜尋停止的條件改為 i <= n - k + path.size() + 1

class Solution {
    List<List<Integer>> res;
    public List<List<Integer>> combine(int n, int k) {
        res = new ArrayList<>();
        if(n < k || k < 0) {
            return res;
        }
        Deque<Integer> path = new ArrayDeque<>();
        dfs(n, k, 1, path);
        return res;
    }
    public void dfs(int n, int k, int begin, Deque<Integer> path) {
        if(path.size() == k) {
            res.add(new ArrayList(path));
            return;
        }
        //剪枝
        for(int i = begin;  i <= n - k + path.size() + 1; ++i) {
            path.add(i);
            dfs(n, k, i + 1, path);
            path.removeLast();
        }
    }
}

組合總和

思路很簡單,和上題類似,剪枝也很容易想到。

class Solution {
    List<List<Integer>> res;
    int[] nums;
    public List<List<Integer>> combinationSum(int[] candidates, int target) {
        res = new ArrayList<>();
        nums = candidates;
        dfs(0, target, new ArrayDeque<Integer>());
        return res;
    }

    public void dfs(int begin, int target, Deque<Integer> path) {
        //剪枝
        if(0 > target) {
            return;
        }
        if(0 == target) {
            res.add(new ArrayList(path));
            return;
        }
        for(int i = begin; i < nums.length; ++i) {
            path.add(nums[i]);
            target -= nums[i];
            dfs(i, target, path);
            target += nums[i];
            path.removeLast();
        }
    }
}

組合總和 II

沒有什麼新意的剪枝去重。

此外,相比為了去重的剪枝,在組合之和已經大於目標值時立即停止搜尋是更有效的剪枝。

如果只有剪枝去重會T。

class Solution {
    List<List<Integer>> res;
    public List<List<Integer>> combinationSum2(int[] candidates, int target) {
        res = new ArrayList<>();
        Arrays.sort(candidates);
        dfs(candidates, 0, target, new ArrayDeque<Integer>());
        return res;
    }
    public void dfs(int[] candidates, int begin, int target, Deque<Integer> path) {
        if(0 == target) {
            res.add(new ArrayList<>(path));
            return;
        }
        for(int i = begin; i < candidates.length; ++i) {
            //組合總和已經 > target,剪枝
            if(0 > target - candidates[i]) {
                break;
            }
            //剪枝去重
            if(i > begin && candidates[i] == candidates[i - 1]) {
                continue;
            }
            path.add(candidates[i]);
            target -= candidates[i];
            dfs(candidates, i + 1, target, path);
            target += candidates[i];
            path.removeLast();
        }
    }
}

組合總和 III

組合1和組合2的綜合,其實到這裡已經索然無味了...

class Solution {
    List<List<Integer>> res;
    public List<List<Integer>> combinationSum3(int k, int n) {
        res = new ArrayList<>();
        if(n > 9 * k || n < k) {
            return res;
        }
        dfs(k, n, 1, new ArrayDeque<Integer>());
        return res;
    }
    public void dfs(int k, int n, int begin, Deque<Integer> path) {
        if(0 == n && path.size() == k) {
            res.add(new ArrayList<>(path));
            return;
        }
        for(int i = begin; i < 10; ++i) {
            path.add(i);
            n -= i;
            dfs(k, n, i + 1, path);
            n += i;
            path.removeLast();
        }
    }
}

組合總和 Ⅳ簡單回溯會超時,按下不表。

參考資料:liweiwei1419的題解