記一次被yield return坑的歷程。
阿新 • • 發佈:2019-02-07
事情的經過是這樣的:
我用C#寫了一個很簡單的一個通過迭代生成序列的函式。
public static IEnumerable<T> Iterate<T>(this Func<T, T> f, T initVal, int length) { Checker.NullCheck(nameof(f), f); Checker.RangeCheck(nameof(length), length, 0, int.MaxValue); var current = initVal; while (--length >= 0) { yield return (current = f(current)); } }
其中NullCheck用於檢查引數是否為null,如果是則丟擲ArgumentNullException異常。
對應的,我寫了如下單元測試程式碼去檢測這個異常。
public void TestIterate() { Func<int, int> f = null; Assert.Throws<ArgumentNullException>(() => f.Iterate(1, 7)); // Other tests}
但是,這個測試出乎意料的fail了。
一開始,我以為是NullCheck函式的問題,可我把NullCheck直接換成了if語句,還是通不過。
後來我在Iterate函式下斷點並除錯。結果偵錯程式根本沒有停在斷點上,直接執行完了測試。
我以為是我測試的方法不對,所以我不斷的修改測試程式碼,甚至還一度以為是.NET的Unit Tests出了bug。
最終,我在這個測試程式碼發現了問題:
Assert.Throws<ArgumentNullException>(() => { var seq = f.Iterate(1, 7); foreach (int ele inseq) Console.WriteLine(ele); });
當我除錯這個測試時,程式停在了我之前在Iterate函式上下的斷點。
於是,我在 var seq = f.Iterate(1, 7); 上下斷點,並逐步執行。這時我發現,當程式執行到 var seq = f.Iterate(1, 7); 時並不會進入Iterate函式;而是當程式執行到foreach語句後才進入。
這就要涉及到yield return的具體工作流程。當函式程式碼中出現yield return,呼叫這個函式會直接返回一個IEnumerable<T>或IEnumerator<T>物件,並不會執行函式體的任何程式碼。這些程式碼都被封裝到了返回物件的內部。它們會在你開始列舉的時候開始執行。
因此,上面兩個Check並不會在函式呼叫時執行,而是在當你開始foreach的時候才執行。
這並不是我想要的結果。我希望在呼叫函式時就檢查引數合法性,如果不合法便直接丟擲異常。
解決這個問題有兩種途徑,一是把它拆成兩個函式:
public static IEnumerable<T> Iterate<T>(this Func<T, T> f, T initVal, int length) { Checker.NullCheck(nameof(f), f); Checker.RangeCheck(nameof(length), length, 0, int.MaxValue); return IterateWithoutCheck(f, initVal, length); } private static IEnumerable<T> IterateWithoutCheck<T>(this Func<T, T> f, T initVal, int length) { var current = initVal; while (--length >= 0) { yield return (current = f(current)); } }
或者,你也可以將這個函式包裝成一個類。
class FunctionIterator<T> : IEnumerable<T> { private readonly Func<T, T> f; private readonly T initVal; private readonly int length; public FunctionIterator(Func<T, T> f, T initVal, int length) { Checker.NullCheck(nameof(f), f); Checker.RangeCheck(nameof(length), length, 0, int.MaxValue); this.f = f; this.initVal = initVal; this.length = length; } public IEnumerator<T> GetEnumerator() { T current = initVal; for (int i = 0; i < length; ++i) yield return (current = f(current)); } System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { return GetEnumerator(); } }