1. 程式人生 > >訓練一個數據不夠多的資料集是什麼體驗?

訓練一個數據不夠多的資料集是什麼體驗?

摘要:這裡介紹其中一種帶標籤擴充資料集的方法。

前言

前一段時間接觸了幾位使用者提的問題,發現很多人在使用訓練的時候,給的資料集寥寥無幾,有一些甚至一類只有5張圖片。modelarts平臺雖然給出了每類5張圖片就能訓練的限制,但是這種限制對一個工業級的應用場景往往是遠遠不夠的。所以聯絡了使用者希望多增加一些圖片,增加幾千張圖片訓練。但是使用者後面反饋,標註的工作量實在是太大了。我思忖了一下,分析了一下他應用的場景,做了一些策略變化。這裡介紹其中一種帶標籤擴充資料集的方法。

資料集情況

資料集由於屬於使用者資料,不能隨便展示,這裡用一個可以展示的開源資料集來替代。首先,這是一個分類的問題,需要檢測出工業零件表面的瑕疵,判斷是否為殘次品,如下是樣例圖片:

這是兩塊太陽能電板的表面,左側是正常的,右側是有殘缺和殘次現象的,我們需要用一個模型來區分這兩類的圖片,幫助定位哪些太陽能電板存在問題。左側的正常樣本754張,右側的殘次樣本358張,驗證集同樣,正常樣本754張,殘次樣本357張。總樣本在2000張左右,對於一般工業要求的95%以上準確率模型而言屬於一個非常小的樣本。先直接拿這個資料集用Pytorch載入imagenet的resnet50模型訓練了一把,整體精度ACC在86.06%左右,召回率正常類為97.3%,但非正常類為62.9%,還不能達到使用者預期。

當要求使用者再多收集,至少擴充到萬級的資料集的時候,使用者提出,收集資料要經過處理,還要標註,很麻煩,問有沒有其他的辦法可以節省一些工作量。這可一下難倒了我,資料可是深度學習訓練的靈魂,這可咋整啊。

仔細思考了一陣子,想到modelarts上有智慧標註然後人工校驗的功能,就讓使用者先試著體驗一下這個功能。我這邊拿他給我的資料集想想辦法。查了些資料,小樣本學習few-shot fewshot learning (FSFSL)的常見方法,基本都是從兩個方向入手。一是資料本身,二是從模型訓練本身,也就是對影象提取的特徵做文章。這裡想著從資料本身入手。

首先觀察資料集,都是300*300的灰度影象,而且都已太陽能電板表面的正面俯視為整張圖片。這屬於預先處理的很好的圖片。那麼針對這種圖片,翻轉映象對圖片整體結構影響不大,所以我們首先可以做的就是flip操作,增加資料的多樣性。flip效果如下:

這樣資料集就從1100張擴增到了2200張,還是不是很多,但是直接觀察資料集已經沒什麼太好的擴充辦法了。這時想到用Modelarts模型評估的功能來評估一下模型對資料的泛化能力。這裡呼叫了提供的SDK:deep_moxing.model_analysis下面的analyse介面。

def validate(val_loader, model, criterion, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')
    pred_list = []
    target_list = []
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)
            # 獲取logits輸出結果pred和實際目標的結果target
            pred_list += output.cpu().numpy()[:, :2].tolist()
            target_list += target.cpu().numpy().tolist()
            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5), i=i)
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)
        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
    # 獲取圖片的儲存路徑name
    name_list = val_loader.dataset.samples
    for idx in range(len(name_list)):
        name_list[idx] = name_list[idx][0]
    analyse(task_type='image_classification', save_path='/home/image_labeled/',
            pred_list=pred_list, label_list=target_list, name_list=name_list)
    return top1.avg

上段程式碼大部分都是Pytorch訓練ImageNet中的驗證部分程式碼,需要獲取三個list,模型pred直接結果logits、圖片實際類別target和圖片儲存路徑name。然後按如上的呼叫方法呼叫analyse介面,會在save_path的目錄下生成一個json檔案,放到Modelarts訓練輸出目錄裡,就能在評估結果裡看到對模型的分析結果。我這裡是線下生成的json檔案再上傳到線上看視覺化結果。關於敏感度分析結果如下:

這幅圖的意思是,不同的特徵值範圍圖片分別測試的精度是多少。比如亮度敏感度分析的第一項0%-20%,可以理解為,在圖片亮度較低的場景下對與0類和其他亮度條件的圖片相比,精度要低很多。整體來看,主要是為了檢測1類,1類在圖片的亮度和清晰度兩項上顯得都很敏感,也就是模型不能很好地處理圖片的這兩項特徵變化的圖片。那這不就是我要擴增資料集的方向嗎?

好的,那麼我就試著直接對全量的資料集做了擴增,得到一個正常類2210張,瑕疵類1174張圖片的資料集,用同樣的策略扔進pytorch中訓練,得到的結果:

怎麼回事,和設想的不太一樣啊。。。

重新分析一下資料集,我突然想到,這種工業類的資料集往往都存在一個樣本不均勻的問題,這裡雖然接近2:1,但是檢測的要求針對有瑕疵的類別的比較高,應該讓模型傾向於有瑕疵類去學習,而且看到1類的也就是有瑕疵類的結果比較敏感,所以其實還是存在樣本不均衡的情況。由此後面的這兩種增強方法只針對了1類也就是有問題的破損類做,最終得到3000張左右,1508張正常類圖片,1432張有瑕疵類圖片,這樣樣本就相對平衡了。用同樣的策略扔進resnet50中訓練。最終得到的精度資訊:

可以看到,同樣在驗證集,正常樣本754張,殘次樣本357張的樣本上,Acc1的精度整體提升了接近3%,重要指標殘次類的recall提升了8.4%!嗯,很不錯。所以直接擴充資料集的方法很有效,而且結合模型評估能讓我參考哪些擴增的方法是有意義的。當然還有很重要的一點,要排除原始資料集存在的問題,比如這裡存在的樣本不均衡問題,具體情況具體分析,這個擴增的方法就會變得簡單實用。

之後基於這個實驗的結果和資料集。給幫助使用者改了一些訓練策略,換了個更厲害的網路,就達到了使用者的要求,當然這都是定製化分析的結果,這裡不詳細展開說明了,或者會在以後的部落格中更新。

引用資料集來自:

Buerhop-Lutz, C.; Deitsch, S.; Maier, A.; Gallwitz, F.; Berger, S.; Doll, B.; Hauch, J.; Camus, C. & Brabec, C. J. A Benchmark for Visual Identification of Defective Solar Cells in Electroluminescence Imagery. European PV Solar Energy Conference and Exhibition (EU PVSEC), 2018. DOI: 10.4229/35thEUPVSEC20182018-5CV.3.15

Deitsch, S.; Buerhop-Lutz, C.; Maier, A. K.; Gallwitz, F. & Riess, C. Segmentation of Photovoltaic Module Cells in Electroluminescence Images. CoRR, 2018, abs/1806.06530

Deitsch, S.; Christlein, V.; Berger, S.; Buerhop-Lutz, C.; Maier, A.; Gallwitz, F. & Riess, C. Automatic classification of defective photovoltaic module cells in electroluminescence images. Solar Energy, Elsevier BV, 2019, 185, 455-468. DOI: 10.1016/j.solener.2019.02.067

 

點選關注,第一時間瞭解華為雲新鮮技