為你寫詩(LSTM 詩歌生成器)
阿新 • • 發佈:2019-02-19
為你寫詩,為你靜止 ,為你做不可能的事。
為你寫詩,你這不是為難我們直男直女的程式設計師們嘛。
雖然我寫不出詩,但不代表我不能訓練一個網路為你寫詩,想要多少寫多少!
所以今天的主題就是如何訓練一個能自動寫詩的LSTM模型。
廢話不多說,程式碼如下:
愛情是一種怪事
我開始全身不受控制
愛情是一種本事
我開始連自己都不是
為你我做了太多的傻事
第一件就是為你寫詩
為你寫詩為你靜止
為你做不可能的事
為你我學會彈琴寫詞
為你失去理智
為你寫詩為你靜止
為你做不可能的事
為你彈奏所有情歌的句子
我忘了說最美的是你的名字
呃呃呃好像不小心貼錯程式碼了,重來:
# coding: utf-8 # In[1]: import re import random import pandas as pd import numpy as np from keras.preprocessing import sequence from keras.optimizers import SGD, RMSprop, Adagrad from keras.utils import np_utils from keras.models import Sequential from keras.layers.core import Dense, Dropout, Activation from keras.layers.embeddings import Embedding from keras.layers.recurrent import LSTM, GRU # In[2]: # 讀取資料, 生成漢字列表 with open('poetry.txt','r', encoding='UTF-8') as f: raw_text = f.read() lines = raw_text.split("\n")[:-1] poem_text = [i.split(':')[1] for i in lines] char_list = [re.findall('[\x80-\xff]{3}|[\w\W]', s) for s in poem_text] # In[3]: # 漢字 <-> 數字 對映 all_words = [] for i in char_list: all_words.extend(i) word_dataframe = pd.DataFrame(pd.Series(all_words).value_counts()) word_dataframe['id'] = list(range(1,len(word_dataframe)+1)) word_index_dict = word_dataframe['id'].to_dict() index_dict = {} for k in word_index_dict: index_dict.update({word_index_dict[k]:k}) len(all_words), len(word_dataframe), len(index_dict) # In[4]: # 生成訓練資料, x 為 前兩個漢字, y 為 接下來的漢字 # 如: 明月幾時有 會被整理成下面三條資料 # 明月 -> 幾 月幾 -> 時 幾時 -> 有 seq_len = 2 dataX = [] dataY = [] for i in range(0, len(all_words) - seq_len, 1): seq_in = all_words[i : i + seq_len] seq_out = all_words[i + seq_len] dataX.append([word_index_dict[x] for x in seq_in]) dataY.append(word_index_dict[seq_out]) len(dataY) # In[5]: X = np.array(dataX) y = np_utils.to_categorical(np.array(dataY)) X.shape, y.shape # In[6]: model = Sequential() # Embedding 層將正整數(下標)轉換為具有固定大小的向量,如[[4],[20]]->[[0.25,0.1],[0.6,-0.2]] # Embedding 層只能作為模型的第一層 # input_dim:大或等於0的整數,字典長度 # output_dim:大於0的整數,代表全連線嵌入的維度 model.add(Embedding(len(word_dataframe), 512)) # LSTM model.add(LSTM(512)) # Dropout 防止過擬合 model.add(Dropout(0.5)) # output 為 y 的維度 model.add(Dense(y.shape[1])) model.add(Activation('softmax')) model.compile(loss='categorical_crossentropy', optimizer='adam') model.summary() # In[7]: # 訓練 model.fit(X, y, batch_size=64, epochs=40) # In[8]: def get_predict_array(seed_text): chars = re.findall('[\x80-\xff]{3}|[\w\W]', seed_text) x = np.array([word_index_dict[k] for k in chars]) proba = model.predict(x, verbose=0) return proba get_predict_array("明月") # 可以看到預測出來的結果是兩個列表, 下一個字是第二個列表 # In[9]: def gen_poetry(model, seed_text, rows=4, cols=5): ''' 生成詩詞的函式 輸入: 兩個漢字, 行數, 每行的字數 (預設為五言絕句) ''' total_cols = cols + 1 # 加上標點符號 chars = re.findall('[\x80-\xff]{3}|[\w\W]', seed_text) if len(chars) != seq_len: # seq_len = 2 return "" arr = [word_index_dict[k] for k in chars] for i in range(seq_len, rows * total_cols): if (i+1) % total_cols == 0: # 逗號或句號 if (i+1) / total_cols == 2 or (i+1) / total_cols == 4: # 句號的情況 arr.append(2) # 句號在字典中的對映為 2 else: arr.append(1) # 逗號在字典中的對映為 1 else: proba = model.predict(np.array(arr[-seq_len:]), verbose=0) predicted = np.argsort(proba[1])[-5:] index = random.randint(0,len(predicted)-1) # 在前五個可能結果裡隨機取, 避免每次都是同樣的結果 new_char = predicted[index] while new_char == 1 or new_char == 2: # 如果是逗號或句號, 應該重新換一個 index = random.randint(0,len(predicted)-1) new_char = predicted[index] arr.append(new_char) poem = [index_dict[i] for i in arr] return "".join(poem) # In[10]: print(gen_poetry(model, '明月')) print(gen_poetry(model, '悠然', rows=4, cols=7)) print(gen_poetry(model, '長河', rows=4, cols=7)) # In[11]: model.save(filepath='lstm_poetry.hdf5') # In[12]: # 試下 GRU gru = Sequential() gru.add(Embedding(len(word_dataframe), 512)) gru.add(GRU(512)) # gru.add(Dropout(0.5)) gru.add(Dense(y.shape[1])) gru.add(Activation('softmax')) gru.compile(loss='categorical_crossentropy', optimizer='adam') # In[13]: gru.summary() # In[14]: gru.fit(X, y, batch_size=64, epochs=40) # In[15]: print(gen_poetry(gru, '明月')) print(gen_poetry(gru, '悠然', rows=4, cols=7)) print(gen_poetry(gru, '長河', rows=4, cols=7)) # In[16]: gru.save('gru_poetry.hdf5')
祝大家都能找到可以為她(他)寫詩的人!