pytorch編譯配置warp-CTC,warpctc_pytorch 編譯
阿新 • • 發佈:2018-12-21
1.下載壓縮包並解壓,或直接git clone原始碼。
git clone https://github.com/SeanNaren/warp-ctc.git2.執行如下指令進行編譯:
- mv warp-ctc-pytorch_bindings/ warp-ctc
- cd warp-ctc
- mkdir build; cd build
- cmake ..
- make
3.開始安裝:
- cd pytorch_binding
- python setup.py install
4. 安裝結束後,在warp-ctc/pytorch_binding/build目錄下新建一個test.py檔案進行測試。目錄結構如下:
檔案內容為:
- import
torch - from warpctc_pytorch import CTCLoss
- ctc_loss = CTCLoss()
- # expected shape of seqLength x batchSize x alphabet_size
- probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
- labels = torch.IntTensor([1, 2])
- label_sizes = torch.IntTensor([2])
- probs_sizes = torch.IntTensor([2
]) - probs.requires_grad_(True) # tells autograd to compute gradients for probs
- cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
- cost.backward()
執行python test.py,若沒有報錯,則證明編譯及執行成功。
6.將其應用於其他專案中。若某專案存在一py格式檔案,存在如下呼叫方式:
則將warp-ctc/pytorch_binding/build/warpctc_pytorch 目錄拷貝至與該py檔案同級的目錄下。
cp -r ~/warp-ctc/ pytorch_binding/build/warpctc_pytorch .
再執行該py檔案即可成功。