pytorch實現帶標籤格式資料的模型訓練
1.訓練資料讀入
注:以下模擬資料,主要講解方法。
標籤資料
下面函式即為實現標籤資料的讀入
def reader(txt):
fh = open(txt)c=0
imgs=[]
class_names=[]
for line in fh.readlines():
if c==0:
class_names=[n.strip() for n in line.rstrip().split(' ')]
else:
cls = line.split()
fn = cls.pop(0)
imgs.append((fn, tuple([float(v) for v in cls])))
c=c+1
return class_names,imgs
其中,返回imgs是標籤元組,即[1,0,0,1],class_names為屬性名,即sex。
如人臉特徵資料,也可以通過reader()讀入。
2.簡單模型設計(以全連層為例)
cmodel=nn.Linear(100, 2) ,(或者nn.Sequential(nn.Linear(100, 2))
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.classify=cmodel
def forward(self, x):
x=self.classify(x)
return x,
3.模型訓練
訓練集讀入
train_data_loader = torch.utils.data.DataLoader( \
ImageFloder(root = "./fea.txt", label = "./label.txt"), batch_size= 2, shuffle= False, num_workers= 4)
其中,root,label分別是特徵與標籤檔案地址, ImageFloder類定義如下:
class ImageFloder(data.Dataset):
def __init__(self, root, label):
self.classes1,self.imgs1 = reader(label)
self.classes2,self.imgs2 = reader(root)
def __getitem__(self, index):
fn1, label1 = self.imgs1[index]
fn2, label2 = self.imgs2[index]
return torch.Tensor(label1),torch.Tensor(label2)
def __len__(self):
return len(self.imgs1)
訓練程式碼詳見專案:
https://github.com/eeric/pytorch-model-training-label