1. 程式人生 > >pytorch下使用LSTM神經網路寫詩

pytorch下使用LSTM神經網路寫詩

在pytorch下,以數萬首唐詩為素材,訓練雙層LSTM神經網路,使其能夠以唐詩的方式寫詩。

程式碼結構分為四部分,分別為

1.model.py,定義了雙層LSTM模型

2.data.py,定義了從網上得到的唐詩資料的處理方法

3.utlis.py 定義了損失視覺化的函式

4.main.py定義了模型引數,以及訓練、唐詩生成函式。

參考:電子工業出版社的《深度學習框架PyTorch:入門與實踐》第九章

main程式碼及註釋如下

import sys, os
import torch as t
from data import get_data
from model import PoetryModel
from torch import nn
from torch.autograd import Variable
from utils import Visualizer
import tqdm
from torchnet import meter
import ipdb

class Config(object):
	data_path = 'data/'
	pickle_path = 'tang.npz'
	author = None
	constrain = None
	category = 'poet.tang' #or poet.song
	lr = 1e-3
	weight_decay = 1e-4
	use_gpu = True
	epoch = 20
	batch_size = 128
	maxlen = 125
	plot_every = 20
	#use_env = True #是否使用visodm
	env = 'poety' 
	#visdom env
	max_gen_len = 200
	debug_file = '/tmp/debugp'
	model_path = None
	prefix_words = '細雨魚兒出,微風燕子斜。' 
	#不是詩歌組成部分,是意境
	start_words = '閒雲潭影日悠悠' 
	#詩歌開始
	acrostic = False 
	#是否藏頭
	model_prefix = 'checkpoints/tang' 
	#模型儲存路徑
opt = Config()

def generate(model, start_words, ix2word, word2ix, prefix_words=None):
	'''
	給定幾個詞,根據這幾個詞接著生成一首完整的詩歌
	'''
	results = list(start_words)
	start_word_len = len(start_words)
	# 手動設定第一個詞為<START>
	# 這個地方有問題,最後需要再看一下
	input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
	if opt.use_gpu:input=input.cuda()
	hidden = None
	
	if prefix_words:
		for word in prefix_words:
			output,hidden = model(input,hidden)
			# 下邊這句話是為了把input變成1*1?
			input = Variable(input.data.new([word2ix[word]])).view(1,1)
	for i in range(opt.max_gen_len):
		output,hidden = model(input,hidden)
		
		if i<start_word_len:
			w = results[i]
			input = Variable(input.data.new([word2ix[w]])).view(1,1)
		else:
			top_index = output.data[0].topk(1)[1][0]
			w = ix2word[top_index]
			results.append(w)
			input = Variable(input.data.new([top_index])).view(1,1)
		if w=='<EOP>':
			del results[-1] #-1的意思是倒數第一個
			break
	return results

def gen_acrostic(model,start_words,ix2word,word2ix, prefix_words = None):
    '''
    生成藏頭詩
    start_words : u'深度學習'
    生成:
    深木通中嶽,青苔半日脂。
    度山分地險,逆浪到南巴。
    學道兵猶毒,當時燕不移。
    習根通古岸,開鏡出清羸。
    '''
    results = []
    start_word_len = len(start_words)
    input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
    if opt.use_gpu:input=input.cuda()
    hidden = None
    
    index=0 # 用來指示已經生成了多少句藏頭詩
    # 上一個詞
    pre_word='<START>'

    if prefix_words:
        for word in prefix_words:
            output,hidden = model(input,hidden)
            input = Variable(input.data.new([word2ix[word]])).view(1,1)

    for i in range(opt.max_gen_len):
        output,hidden = model(input,hidden)
        top_index  = output.data[0].topk(1)[1][0]
        w = ix2word[top_index]

        if (pre_word  in {u'。',u'!','<START>'} ):
            # 如果遇到句號,藏頭的詞送進去生成

            if index==start_word_len:
                # 如果生成的詩歌已經包含全部藏頭的詞,則結束
                break
            else:  
                # 把藏頭的詞作為輸入送入模型
                w = start_words[index]
                index+=1
                input = Variable(input.data.new([word2ix[w]])).view(1,1)    
        else:
            # 否則的話,把上一次預測是詞作為下一個詞輸入
            input = Variable(input.data.new([word2ix[w]])).view(1,1)
        results.append(w)
        pre_word = w
    return results

def train(**kwargs):
	
	for k,v in kwargs.items():
		setattr(opt,k,v) #設定apt裡屬性的值
	vis = Visualizer(env=opt.env)
	
	#獲取資料
	data, word2ix, ix2word = get_data(opt) #get_data是data.py裡的函式
	data = t.from_numpy(data)
	#這個地方出錯了,是大寫的L
	dataloader = t.utils.data.DataLoader(data, 
					batch_size = opt.batch_size,
					shuffle = True,
					num_workers = 1) #在python裡,這樣寫程式可以嗎?
    #模型定義
	model = PoetryModel(len(word2ix), 128, 256)
	optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
	criterion = nn.CrossEntropyLoss()
    
	if opt.model_path:
		model.load_state_dict(t.load(opt.model_path))
	if opt.use_gpu:
		model.cuda()
		criterion.cuda()
		
	#The tnt.AverageValueMeter measures and returns the average value 
	#and the standard deviation of any collection of numbers that are 
	#added to it. It is useful, for instance, to measure the average 
	#loss over a collection of examples.

    #The add() function expects as input a Lua number value, which 
    #is the value that needs to be added to the list of values to 
    #average. It also takes as input an optional parameter n that 
    #assigns a weight to value in the average, in order to facilitate 
    #computing weighted averages (default = 1).

    #The tnt.AverageValueMeter has no parameters to be set at initialization time. 
	loss_meter = meter.AverageValueMeter()
	
	for epoch in range(opt.epoch):
		loss_meter.reset()
		for ii,data_ in tqdm.tqdm(enumerate(dataloader)):
			#tqdm是python中的進度條
			#訓練
			data_ = data_.long().transpose(1,0).contiguous()
			#上邊一句話,把data_變成long型別,把1維和0維轉置,把記憶體調成連續的
			if opt.use_gpu: data_ = data_.cuda()
			optimizer.zero_grad()
			input_, target = Variable(data_[:-1,:]), Variable(data_[1:,:])
			#上邊一句,將輸入的詩句錯開一個字,形成訓練和目標
			output,_ = model(input_)
			loss = criterion(output, target.view(-1))
			loss.backward()
			optimizer.step()
			
			loss_meter.add(loss.data[0]) #為什麼是data[0]?
			
			#視覺化用到的是utlis.py裡的函式
			if (1+ii)%opt.plot_every ==0:
				
				if os.path.exists(opt.debug_file):
					ipdb.set_trace()
				vis.plot('loss',loss_meter.value()[0])
				
				# 下面是對目前模型情況的測試,詩歌原文
				poetrys = [[ix2word[_word] for _word in data_[:,_iii]] 
									for _iii in range(data_.size(1))][:16]
				#上面句子嵌套了兩個迴圈,主要是將詩歌索引的前十六個字變成原文
				vis.text('</br>'.join([''.join(poetry) for poetry in 
				poetrys]),win = u'origin_poem')
				gen_poetries = []
				#分別以以下幾個字作為詩歌的第一個字,生成8首詩
				for word in list(u'春江花月夜涼如水'):
					gen_poetry = ''.join(generate(model,word,ix2word,word2ix))
					gen_poetries.append(gen_poetry)
				vis.text('</br>'.join([''.join(poetry) for poetry in 
				gen_poetries]), win = u'gen_poem')
		t.save(model.state_dict(), '%s_%s.pth' %(opt.model_prefix,epoch))

def gen(**kwargs):
	'''
	提供命令列介面,用以生成相應的詩
	'''
	
	for k,v in kwargs.items():
		setattr(opt,k,v)
	data, word2ix, ix2word = get_data(opt)
	model = PoetryModel(len(word2ix), 128, 256)
	map_location = lambda s,l:s
	# 上邊句子裡的map_location是在load裡用的,用以載入到指定的CPU或GPU,
	# 上邊句子的意思是將模型載入到預設的GPU上
	state_dict = t.load(opt.model_path, map_location = map_location)
	model.load_state_dict(state_dict)
	
	if opt.use_gpu:
		model.cuda()
	if sys.version_info.major == 3:
		if opt.start_words.insprintable():
			start_words = opt.start_words
			prefix_words = opt.prefix_words if opt.prefix_words else None
		else:
			start_words = opt.start_words.encode('ascii',\
			'surrogateescape').decode('utf8')
			prefix_words = opt.prefix_words.encode('ascii',\
			'surrogateescape').decode('utf8') if opt.prefix_words else None
		start_words = start_words.replace(',',u',')\
											.replace('.',u'。')\
											.replace('?',u'?')
		gen_poetry = gen_acrostic if opt.acrostic else generate
		result = gen_poetry(model,start_words,ix2word,word2ix,prefix_words)
		print(''.join(result))
if __name__ == '__main__':
	import fire
	fire.Fire()

以上程式碼給我一些經驗,

1. 瞭解python的程式設計方式,如空格、換行等;進一步瞭解python的各個基本模組;

2. 可能出的錯誤:函式名寫錯,大小寫,變數名寫錯,括號不全。

3. 對cuda()的用法有了進一步認識;

4. 學會了除錯程式(fire);

5. 學會了訓練結果的視覺化(visdom);

6. 進一步的瞭解了LSTM,對深度學習的架構、實現有了巨集觀把控。