pytorch: 如何優雅的將 int list 轉成 one-hot形式
阿新 • • 發佈:2019-01-04
雖然 pytorch 已經升級到 0.2.0 了,但是,貌似依舊沒有簡單的 api 來幫助我們快速將 int list 轉成 one-hot。那麼,如何優雅的實現 one-hot 程式碼呢?
def one_hot(ids, out_tensor):
"""
ids: (list, ndarray) shape:[batch_size]
out_tensor:FloatTensor shape:[batch_size, depth]
"""
if not isinstance(ids, (list, np.ndarray)):
raise ValueError("ids must be 1-D list or array")
ids = torch.LongTensor(ids).view(-1,1)
out_tensor.zero_()
out_tensor.scatter_(dim=1, index=ids, src=1.)
# out_tensor.scatter_(1, ids, 1.0)
scatter_
是什麼鬼?
從 value 中拿值,然後根據 dim 和 index 給自己的相應位置填上值
Tensor.scatter_(dim, index, src)
# index: LongTensor
# out[index[i, j], j] = value[i, j] dim=0
# out[i,index[i, j]] = value[i, j]] dim=1
# index 的 shape 可以不和 out 的 shape 一致
# value 也可以是一個 float 值, 也可以是一個 FloatTensor
# 如果 value 是 FloatTensor 的話,那麼shape 需要和 index 保持一致