1. 程式人生 > >Pytorch模型的儲存與讀取方法

Pytorch模型的儲存與讀取方法

方法一(推薦)

只儲存和載入模型的引數

# 儲存模型引數
def save_model(the_model, PATH):
    torch.save(the_model.state_dict(), PATH)
# 載入模型引數
def load_model(PATH):
    the_model = TheModelClass(*args, **kwargs)
    the_model.load_state_dict(torch.load(PATH))

方法二

在這種情況下,序列化的資料被繫結到特定的類和固定的目錄結構,所以當在其他專案中使用時,或者在一些嚴重的重構器之後它可能會以各種方式break。

# 儲存模型引數
def save_model(the_model, PATH):
    torch.save(the_model, PATH)
# 載入模型引數
def load_model(PATH):
    the_model = torch.load(PATH)