程式人生 > pytorch c++ 多分類問題,計算百分比

pytorch c++ 多分類問題：計算預測正確的百分比（準確率）

pytorch c++（LibTorch）環境搭建方法請參考原文中的連結。

#include <torch/script.h>
#include <ATen/ATen.h>

#include <iostream>
#include <memory>

using namespace std;
using namespace at;

// Demo: compute a multi-class "accuracy" (fraction of matching predictions)
// with LibTorch.
//
// Two random 5x2 score tensors are drawn. Each row's predicted class is the
// arg-max over dim 1. The program prints the share of rows on which the two
// arg-max vectors agree (a value in [0, 1]; multiply by 100 for a percentage).
int main(int argc, const char* argv[])
{
    // Command-line arguments are not used by this demo.
    (void)argc;
    (void)argv;

    // Fixed seed so the printed tensors (and the resulting rate) are reproducible.
    torch::manual_seed(0);

    // First score tensor: 5 rows, 2 scores per row; argmax over dim 1 picks the class index.
    torch::Tensor a = torch::randn({5, 2});
    std::cout << a << '\n';
    auto i_a = at::argmax(a, 1);
    std::cout << i_a << '\n';

    // Second score tensor, produced the same way.
    torch::Tensor b = torch::randn({5, 2});
    std::cout << b << '\n';
    auto i_b = at::argmax(b, 1);
    std::cout << i_b << '\n';

    // Element-wise equality of the two index vectors. Use the public at::eq;
    // at::_th_eq was an internal TH binding that later releases removed.
    auto result = at::eq(i_a, i_b);
    std::cout << result << '\n';

    // The comparison yields a non-floating tensor (uint8 in older releases,
    // bool in newer ones); mean() needs a floating dtype, so convert first.
    auto f_result = result.to(torch::kFloat32);
    auto rate = at::mean(f_result);
    std::cout << rate << '\n';

    std::cout << "ok\n";
    return 0;  // 0 = success for the shell / build tooling
}

編譯並執行

make clean
make
./bin/demo

輸出結果

user@host:~/pytorch_c++/test$ ./bin/demo
 0.8809  2.3786
 0.2025  0.3694
 1.8396 -0.4696
 0.1447  0.7579
 0.0406 -0.7104
[ Variable[CPUFloatType]{5,2} ]
 1
 1
 0
 1
 0
[ Variable[CPULongType]{5} ]
 1.7135  1.3517
-0.6128  1.0147
 1.3197 -0.8938
-0.9867  0.2056
 0.6811 -0.8663
[ Variable[CPUFloatType]{5,2} ]
 0
 1
 0
 1
 0
[ Variable[CPULongType]{5} ]
 0
 1
 1
 1
 1
[ Variable[CPUByteType]{5} ]
0.8
[ Variable[CPUFloatType]{} ]
ok