1. 程式人生 > 其它 >halide程式設計技術指南(連載八)

halide程式設計技術指南(連載八)

技術標籤:深度學習深度學習機器學習神經網路

本文是halide程式設計指南的連載,已同步至公眾號

第13章 元組

// 本課程介紹如何編寫求多個值的函式.
// 在linux系統, 按如下編譯執行:
// g++ lesson_13*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -lpthread -ldl -o lesson_13 -std=c++11
// LD_LIBRARY_PATH=<path/to/libHalide.so> ./lesson_13
// 在 os x:
// g++ lesson_13*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -o lesson_13 -std=c++11
// DYLD_LIBRARY_PATH=<path/to/libHalide.dylib> ./lesson_13
// 如果你有halide的原始碼,也可以在原始碼的最頂層目錄,這樣:
//    make tutorial_lesson_13_tuples
#include "Halide.h"
#include <algorithm>
#include <stdio.h>
using namespace Halide;
int main(int argc, char **argv) {

//到目前為止,Funcs(如下面的函式)已經為其域中的每個點計算了一個標量值    
Func single_valued;
    Var x, y;
    single_valued(x, y) = x + y;

    // 編寫返回值集合的Func的一種方法是給返回值新增索引。這就是我們通常處理顏色的方式。例如,下面的Func表示由c索引的每個x,y座標的三個值的集合.
    Func color_image;
    Var c;
    color_image(x, y, c) = select(c == 0, 245,  // Red value
                                  c == 1, 42,   // Green value
                                  132);         // Blue value

    // 由於這種模式經常出現,Halide使用“mux”函式提供了一個syntactic sugar來編寫上面的程式碼,如下所示.
    // color_image(x, y, c) = mux(c, {245, 42, 132});

    // 這種方法通常是方便的,因為它對這個Func的操作變得容易,並且對集合中的每個項都一視同仁:
    Func brighter;
    brighter(x, y, c) = color_image(x, y, c) + 10;

    // 然而,這種方法也不方便有三個原因.
    //
    // 1) Func是在一個無限域上定義的,因此該Func的使用者可以訪問例如color_image(x,y,-17),這不是一個有意義的值,可能表示有bug.
    //
    // 2) 它需要一個select,如果沒有繫結和展開,它會影響效能:
    // brighter.bound(c, 0, 3).unroll(c);
    //
    // 3) 使用此方法,集合中的所有值必須具有相同的型別。雖然上述兩個問題只是不方便,但這是一個硬性的限制,無法用這種方式表達某些東西.

    // 也可以將值的集合表示為func的集合:
    Func func_array[3];
    func_array[0](x, y) = x + y;
    func_array[1](x, y) = sin(x);
    func_array[2](x, y) = cos(y);

    // 這種方法避免了上述三個問題,但引入了一個新的煩惱。因為這些函式是獨立的,所以很難對它們進行排程,以便它們都在x,y上的一個迴圈中一起計算.

    // 第三種方法是將Func定義為對元組求值,而不是對錶達式求值。元組是表示式的固定大小集合。元組中的每個表示式可能有不同的型別。以下函式的計算結果為整數值(x+y)和浮點值(sin(x*y)).
    Func multi_valued;
    multi_valued(x, y) = Tuple(x + y, sin(x * y));

    // 實現一個元組值Func返回一個緩衝區集合。我們稱之為實現。它相當於緩衝區物件的std::vector:
    {
        Realization r = multi_valued.realize(80, 60);
        assert(r.size() == 2);
        Buffer<int> im0 = r[0];
        Buffer<float> im1 = r[1];
        assert(im0(30, 40) == 30 + 40);
        assert(im1(30, 40) == sinf(30 * 40));
    }

    // 所有元組元素在同一迴圈巢狀中的同一域上一起計算,但儲存在不同的分配中。上面的C++程式碼是:
    {
        int multi_valued_0[80 * 60];
        float multi_valued_1[80 * 60];
        for (int y = 0; y < 80; y++) {
            for (int x = 0; x < 60; x++) {
                multi_valued_0[x + 60 * y] = x + y;
                multi_valued_1[x + 60 * y] = sinf(x * y);
            }
        }
    }

    // 在提前編譯時,元組值Func計算為多個不同的輸出halide_buffer_t結構體。它們依次出現在函式簽名的末尾:
    // int multi_valued(...input buffers and params...,
    //                  halide_buffer_t *output_1, halide_buffer_t *output_2);

    // 您可以通過向元組建構函式傳遞多個表示式來構造元組,就像我們上面所做的那樣。也許更優雅,您還可以利用C++ 11初始化列表,只需在括號中包含ExpRs即可:
    Func multi_valued_2;
    multi_valued_2(x, y) = {x + y, sin(x * y)};

    // 對多值函式的呼叫不能視為表示式。以下是語法錯誤:
    // Func consumer;
    // consumer(x, y) = multi_valued_2(x, y) + 10;

    // 相反,您必須用方括號索引一個元組來檢索各個表示式:
    Expr integer_part = multi_valued_2(x, y)[0];
    Expr floating_part = multi_valued_2(x, y)[1];
    Func consumer;
    consumer(x, y) = {integer_part + 10, floating_part + 10.0f};

    // 元組約化.
    {
        // 元組在歸約中特別有用,因為它們允許歸約在其域中執行時保持複雜狀態。最簡單的例子是argmax.

        // 首先,我們建立一個緩衝區來接管argmax.
        Func input_func;
        input_func(x) = sin(x);
        Buffer<float> input = input_func.realize(100);

        // 然後我們定義一個二值元組來跟蹤最大值的索引和值本身.
        Func arg_max;

        // 純定義.
        arg_max() = {0, input(0)};

        // 更新.
        RDom r(1, 99);
        Expr old_index = arg_max()[0];
        Expr old_max = arg_max()[1];
        Expr new_index = select(old_max < input(r), r, old_index);
        Expr new_max = max(input(r), old_max);
        arg_max() = {new_index, new_max};

        // 等效C:
        int arg_max_0 = 0;
        float arg_max_1 = input(0);
        for (int r = 1; r < 100; r++) {
            int old_index = arg_max_0;
            float old_max = arg_max_1;
            int new_index = old_max < input(r) ? r : old_index;
            float new_max = std::max(input(r), old_max);
            // 在元組更新定義中,所有的載入和計算都是在任何儲存之前完成的,因此所有元組元素都是相對於對同一Func的遞迴呼叫進行原子更新的.
            arg_max_0 = new_index;
            arg_max_1 = new_max;
        }

        // 讓我們驗證halide和C++找到相同的最大值和索引.
        {
            Realization r = arg_max.realize();
            Buffer<int> r0 = r[0];
            Buffer<float> r1 = r[1];
            assert(arg_max_0 == r0(0));
            assert(arg_max_1 == r1(0));
        }

        // halide提供argmax和argmin作為內建函式,類似於總和、乘積、最大值和最小值。它們返回一個元組,元組由對應於該值的歸約域中的點和值本身組成。對於一個tie,它們返回找到的第一個值。我們將在下一節中使用其中一個.
    }

    // 使用者定義的元組.
    {
        // 元組也是表示複合物件(如複數)的方便方法。定義一個可以與元組進行轉換的物件是用使用者定義的型別擴充套件Halide的型別系統的一種方法.
        struct Complex {
            Expr real, imag;

            // 從元組構建
            Complex(Tuple t)
                : real(t[0]), imag(t[1]) {
            }

            // 從一對 Exprs構建
            Complex(Expr r, Expr i)
                : real(r), imag(i) {
            }

            // 通過將Func當作元組來構造對它的呼叫
            Complex(FuncRef t)
                : Complex(Tuple(t)) {
            }

            // 轉換為元組
            operator Tuple() const {
                return {real, imag};
            }

            // 複合加法
            Complex operator+(const Complex &other) const {
                return {real + other.real, imag + other.imag};
            }

            // 複數乘法
            Complex operator*(const Complex &other) const {
                return {real * other.real - imag * other.imag,
                        real * other.imag + imag * other.real};
            }

            // 復振幅,效率平方
            Expr magnitude_squared() const {
                return real * real + imag * imag;
            }

            // 其他複雜的操作符會在這裡。對於這個例子,上面的內容就足夠了.
        };

        // 讓我們使用複雜結構來計算一個Mandelbrot集.
        Func mandelbrot;

        // 函式中x,y座標對應的初始復值.
        Complex initial(x / 15.0f - 2.5f, y / 6.0f - 2.0f);

        // 純定義.
        Var t;
        mandelbrot(x, y, t) = Complex(0.0f, 0.0f);

        // 我們將使用更新定義來執行12個步驟.
        RDom r(1, 12);
        Complex current = mandelbrot(x, y, r - 1);

        // 下面一行使用我們上面定義的複數乘法和加法.
        mandelbrot(x, y, r) = current * current + initial;

        // 我們將使用另一個元組來計算迭代次數,其中值首先轉義半徑為4的圓。這可以表示為布林表示式的argmin—我們希望第一次給定布林表示式的索引為false(我們認為false小於true)。argmax將返回表示式第一次為真時的索引.

        Expr escape_condition = Complex(mandelbrot(x, y, r)).magnitude_squared() < 16.0f;
        Tuple first_escape = argmin(escape_condition);

        // 我們只需要索引,不需要值,但是argmin返回這兩個值,所以我們將使用方括號索引argmin元組表示式,以獲得表示索引的表示式.
        Func escape;
        escape(x, y) = first_escape[0];

        // 實現流水線並以ascii格式列印結果.
        Buffer<int> result = escape.realize(61, 25);
        const char *code = " .:-~*={}&%#@";
        for (int y = 0; y < result.height(); y++) {
            for (int x = 0; x < result.width(); x++) {
                printf("%c", code[result(x, y)]);
            }
            printf("\n");
        }
    }

    printf("Success!\n");

return 0;
}

第14章 型別系統

// 這一課更精確地描述halide的型別系統.
// linux, 
// g++ lesson_14*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -lpthread -ldl -o lesson_14 -std=c++11
// LD_LIBRARY_PATH=<path/to/libHalide.so> ./lesson_14
// os x:// g++ lesson_14*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -o lesson_14 -std=c++11
// DYLD_LIBRARY_PATH=<path/to/libHalide.dylib> ./lesson_14
// 在原始碼樹,可以執行
//    make tutorial_lesson_14_types
#include "Halide.h"
#include <stdio.h>
using namespace Halide;
// 此函式用於在本課程結束時演示通用程式碼.
Expr average(Expr a, Expr b);
int main(int argc, char **argv) {

    // 所有表示式都有一個標量型別,所有函式的計算結果都是一個或多個標量型別。Halide中的標量型別是各種位寬度的無符號整數、同一組位寬度的有符號整數、單精度和雙精度的浮點數以及不透明控制代碼(相當於void*)。以下陣列包含所有合法型別.

    Type valid_halide_types[] = {
        UInt(8), UInt(16), UInt(32), UInt(64),
        Int(8), Int(16), Int(32), Int(64),
        Float(32), Float(64), Handle()};

    // 構造和檢查型別.
    {
        // 可以通過程式設計方式檢查halide型別的屬性。當編寫具有ExPR引數的C++函式時,這是有用的,並且您希望檢查它們的型別。:
        assert(UInt(8).bits() == 8);
        assert(Int(8).is_int());

        // 也可以通過程式設計方式將型別構造為其他型別的函式.
        Type t = UInt(8);
        t = t.with_bits(t.bits() * 2);
        assert(t == UInt(16));

        // 或者從C++標量型別構造型別
        assert(type_of<float>() == Float(32));

        // 型別結構也能夠表示向量型別,但這是為Halide的內部使用而保留的。應該使用Func::vectorize對程式碼進行向量化,而不是試圖直接構造矢量表達式。如果以程式設計方式操作低halide程式碼,可能會遇到向量型別,但這是一個高階主題 (檢視Func::add_custom_lowering_pass).

        // 您可以查詢任何鹵化物表示式的型別。Expr代表Var是Int(32)型別:
        Var x;
        assert(Expr(x).type() == Int(32));

        // halide中大多數transcendental 函式將輸入型別轉換為Float(32) 並返回Float(32):
        assert(sin(x).type() == Float(32));

        // 可以通過如下操作將 Expr從一個型別轉換為另一個型別:
        assert(cast(UInt(8), x).type() == UInt(8));

        // 這也是以C++形式的模板形式出現的.
        assert(cast<uint8_t>(x).type() == UInt(8));

        // 您還可以查詢任何已定義的Func以獲取其生成的型別.
        Func f1;
        f1(x) = cast<uint8_t>(x);
        assert(f1.output_types()[0] == UInt(8));

        Func f2;
        f2(x) = {x, sin(x)};
        assert(f2.output_types()[0] == Int(32) &&
               f2.output_types()[1] == Float(32));
    }

    // 型別提升規則.
    {
        // 當您組合不同型別的表示式(例如使用“+”、“*”等)時,Halide使用型別提升規則系統。這些與C的規則不同。為了演示這些,我們將對每種型別進行一些表示式.
        Var x;
        Expr u8 = cast<uint8_t>(x);
        Expr u16 = cast<uint16_t>(x);
        Expr u32 = cast<uint32_t>(x);
        Expr u64 = cast<uint64_t>(x);
        Expr s8 = cast<int8_t>(x);
        Expr s16 = cast<int16_t>(x);
        Expr s32 = cast<int32_t>(x);
        Expr s64 = cast<int64_t>(x);
        Expr f32 = cast<float>(x);
        Expr f64 = cast<double>(x);

        // 規則如下所示,並按以下順序應用.

        // 1) 對Handle()型別的表示式強制轉換或使用算術運算子是錯誤的.

        // 2) 如果型別相同,則不會發生型別轉換.
        for (Type t : valid_halide_types) {
            // 跳過控制代碼型別.
            if (t.is_handle()) continue;
            Expr e = cast(t, x);
            assert((e + e).type() == e.type());
        }

        // 3) 如果一個型別是float而另一個不是,那麼non-float引數將提升為float(可能導致大整數的精度損失).
        assert((u8 + f32).type() == Float(32));
        assert((f32 + s64).type() == Float(32));
        assert((u16 + f64).type() == Float(64));
        assert((f64 + s32).type() == Float(64));

        // 4) 如果這兩種型別都是float,則較窄的引數將提升為較寬的位寬度.
        assert((f64 + f32).type() == Float(64));

        // 上面的規則處理所有浮點情況。以下三條規則處理整數情況.

        // 5) 如果其中一個引數是C++ int,而另一個是halide::ExPR,則int被強制轉換為表示式的型別。.
        assert((u32 + 3).type() == UInt(32));
        assert((3 + s16).type() == Int(16));

        // 如果此規則會導致整數溢位,則Halide將觸發錯誤,例如,取消對以下行的註釋將導致此程式以錯誤終止.
        // Expr bad = u8 + 257;

        // 6) 如果兩種型別都是無符號整數,或者兩種型別都是有符號整數,則較窄的引數將提升為較寬的型別.
        assert((u32 + u8).type() == UInt(32));
        assert((s16 + s64).type() == Int(64));

        // 7) 如果一種型別是有符號的,而另一種是無符號的,則兩個引數都將提升為一個有符號整數,其寬度為兩個位寬度中的較大值.
        assert((u8 + s32).type() == Int(32));
        assert((u32 + s8).type() == Int(32));

        // 注意,在位元寬度相同的情況下,這可能會悄悄地溢位無符號型別.
        assert((u32 + s32).type() == Int(32));

        // 以這種方式將無符號表達式轉換為更寬的有符號型別時,首先將其擴充套件為更寬的無符號型別(零擴充套件),然後重新解釋為有符號整數。即,將UInt(8)值255轉換為Int(32)產生255,而不是-1.
        int32_t result32 = evaluate<int>(cast<int32_t>(cast<uint8_t>(255)));
        assert(result32 == 255);

        // 當使用強制轉換運算子將有符號型別顯式轉換為更寬的無符號型別(型別提升規則不會自動執行此操作)時,首先將其轉換為更寬的有符號型別(符號擴充套件),然後重新解釋為無符號整數。即,將Int(8)值-1轉換為UInt(16)產生65535,而不是255.
        uint16_t result16 = evaluate<uint16_t>(cast<uint16_t>(cast<int8_t>(-1)));
        assert(result16 == 65535);
    }

    // Handle()型別.
    {
        // 控制代碼用於表示不透明指標。將type_of應用於任何指標型別將返回Handle()
        assert(type_of<void *>() == Handle());
        assert(type_of<const char *const **>() == Handle());

        // 不管編譯目標是什麼,控制代碼始終儲存為64位.
        assert(Handle().bits() == 64);

        // Handle型別的Expr的主要用途是將它通過halide傳遞給其他外部程式碼.
    }

    // 通用程式碼.
    {
        // Type在Halide中的主要顯式用法是編寫由Type引數化的Halide程式碼。在C++中,你可以用模板來完成這個操作。在halide中,不需要 —— 可以在C++執行時動態地檢查和修改型別。下面定義的函式用來平均任意相等數值型別的兩個表示式.
        Var x;
        assert(average(cast<float>(x), 3.0f).type() == Float(32));
        assert(average(x, 3).type() == Int(32));
        assert(average(cast<uint8_t>(x), cast<uint8_t>(3)).type() == UInt(8));
    }

    printf("Success!\n");

    return 0;}

Expr average(Expr a, Expr b) {
    // 型別必須匹配.
    assert(a.type() == b.type());

    // 對於浮點型別:
    if (a.type().is_float()) {
        // 由於上面的規則3,“2”將升級為浮點型別.
        return (a + b) / 2;
    }

    // 對於整數型別,我們必須在更寬的型別中計算中間值以避免溢位.
    Type narrow = a.type();
    Type wider = narrow.with_bits(narrow.bits() * 2);
    a = cast(wider, a);
    b = cast(wider, b);
    return cast(narrow, (a + b) / 2);
}