PyTorch C++(LibTorch)多分類問題:計算預測準確率(百分比)
阿新 • 發佈:2019-01-05
PyTorch C++ 環境搭建請參考作者的另一篇文章(原文中的「點選」連結在轉載時已遺失)。
// Demo for a multi-class classification setting: take the per-row argmax of two
// random score matrices (think "predictions" vs. "labels") and report the
// fraction of rows on which they agree, i.e. the accuracy as a value in [0, 1].
#include <torch/script.h>
#include <ATen/ATen.h>

#include <iostream>
#include <memory>

// No `using namespace std;` / `using namespace at;` at file scope: every name
// below is qualified explicitly, which avoids namespace pollution and ambiguity.
int main(int /*argc*/, const char* /*argv*/[]) {
  // Fixed seed so the random tensors (and therefore the printed output) are
  // reproducible from run to run.
  torch::manual_seed(0);

  // 5x2 tensor of standard-normal values; each row is treated here as one
  // sample and each column as the score for one of 2 classes.
  torch::Tensor a = torch::randn({5, 2});
  std::cout << a << '\n';
  // Index of the largest entry along dimension 1 -> one class index per row.
  auto i_a = at::argmax(a, 1);
  std::cout << i_a << '\n';

  torch::Tensor b = torch::randn({5, 2});
  std::cout << b << '\n';
  auto i_b = at::argmax(b, 1);
  std::cout << i_b << '\n';

  // Element-wise equality of the two class-index tensors (1 where they match).
  // Use the public at::eq; at::_th_eq is an internal legacy TH binding that is
  // not part of the supported API and was removed in later releases.
  auto result = at::eq(i_a, i_b);
  std::cout << result << '\n';

  // Type conversion: the comparison result is an integral tensor (a byte tensor
  // in older releases, bool in newer ones), and at::mean needs a floating-point
  // input, so convert to float32 first.
  auto f_result = result.to(torch::kFloat32);
  // The mean of the 0/1 matches is the fraction of agreeing rows; multiply by
  // 100 to express it as a percentage.
  auto rate = at::mean(f_result);
  std::cout << rate << '\n';

  std::cout << "ok\n";
  // Exit status 0 signals success. The original returned 1, which shells and
  // make-driven scripts interpret as failure.
  return 0;
}
編譯並執行
make clean
make
./bin/demo
輸出結果
user@host:~/pytorch_c++/test$ ./bin/demo
 0.8809  2.3786
 0.2025  0.3694
 1.8396 -0.4696
 0.1447  0.7579
 0.0406 -0.7104
[ Variable[CPUFloatType]{5,2} ]
 1
 1
 0
 1
 0
[ Variable[CPULongType]{5} ]
 1.7135  1.3517
-0.6128  1.0147
 1.3197 -0.8938
-0.9867  0.2056
 0.6811 -0.8663
[ Variable[CPUFloatType]{5,2} ]
 0
 1
 0
 1
 0
[ Variable[CPULongType]{5} ]
 0
 1
 1
 1
 1
[ Variable[CPUByteType]{5} ]
0.8
[ Variable[CPUFloatType]{} ]
ok