Files
flask_lstm_poem_generator/generate.py
T
2023-05-12 18:45:02 +08:00

113 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import functools
import re

import numpy as np
import torch as t

from model import PoetryModel
from config import Config
@functools.lru_cache(maxsize=1)
def _load_resources():
    """Load the vocabulary and the trained PoetryModel once; later calls reuse them.

    Returns:
        ``(model, ix2word, word2ix)``.
    """
    print("正在初始化......")
    # allow_pickle is needed because the vocab dicts are pickled objects;
    # tang.npz must therefore be a trusted local file.
    # The context manager closes the NpzFile; only the two vocab dicts are
    # read (the large 'data' array is not needed for generation).
    with np.load("tang.npz", allow_pickle=True) as datas:
        ix2word = datas['ix2word'].item()
        word2ix = datas['word2ix'].item()
    model = PoetryModel(len(ix2word), Config.embedding_dim, Config.hidden_dim)
    model.load_state_dict(t.load(Config.model_path, map_location='cpu'))
    if Config.use_gpu:
        model.to(t.device('cuda'))
    # Inference only: disables dropout / batch-norm updates if PoetryModel has any.
    model.eval()
    print("初始化完成!\n")
    return model, ix2word, word2ix


def gen(start_words, action):
    """Generate a poem and return it split into lines.

    The model and vocabulary are loaded on the first call and cached for
    subsequent calls.

    Args:
        start_words: characters that seed the poem.
        action: ``1`` continues a poem that begins with ``start_words``; any
            other value writes an acrostic poem whose lines start with the
            successive characters of ``start_words``.

    Returns:
        List of strings. A continued poem is split at each full stop ``。``;
        an acrostic poem is split at each ``，`` and ``。``, so each element
        corresponds to one acrostic line. A trailing empty string appears when
        the text ends with a separator.
    """
    model, ix2word, word2ix = _load_resources()
    with t.no_grad():
        if action == 1:
            gen_poetry = ''.join(generate(model, start_words, ix2word, word2ix))
            # NOTE(review): the original separator literal was lost (split('')
            # always raised ValueError); the full stop is assumed here — confirm.
            arr = gen_poetry.split('。')
        else:
            gen_poetry = ''.join(gen_acrostic(model, start_words, ix2word, word2ix))
            # Acrostic poems are split at both the full-width comma and full stop.
            arr = re.split('[，。]', gen_poetry)
    print(arr)
    return arr
# Generate a poem that continues the given opening words.
def generate(model, start_words, ix2word, word2ix, prefix_words=None):
    """Generate a poem whose text begins with ``start_words``.

    The model is first fed ``<START>`` and, if given, the characters of
    ``prefix_words`` so the hidden state is primed with a "style". It is then
    forced through ``start_words`` one character at a time (predictions are
    ignored; only the hidden state matters). Afterwards it decodes greedily,
    feeding back the highest-scoring token, until ``<EOP>`` is produced or
    ``Config.max_gen_len`` model steps have run.

    Args:
        model: called as ``model(input, hidden)`` returning ``(output, hidden)``;
            ``output[0]`` is read as the scores for the next token.
        start_words: opening characters of the poem.
        ix2word: mapping from token index to token string.
        word2ix: mapping from token string to token index.
        prefix_words: optional style-priming characters; fed to the model but
            not included in the result.

    Returns:
        List of token strings: the characters of ``start_words`` followed by
        the generated tokens, without the terminating ``<EOP>``.

    Raises:
        KeyError: if a character of ``start_words`` or ``prefix_words`` is not
            in ``word2ix``.
    """
    device = t.device('cuda') if Config.use_gpu else t.device('cpu')

    def to_input(ix):
        # A (1, 1) LongTensor holding one token index, on the model's device.
        return t.tensor([[ix]], dtype=t.long, device=device)

    results = list(start_words)
    start_words_len = len(start_words)
    hidden = None
    # Inference only: without no_grad every step's graph stays alive through
    # ``hidden`` for the whole generation.
    with t.no_grad():
        # The first input is always <START>.
        inp = to_input(word2ix['<START>'])
        # Optional style prefix: run it only to build up ``hidden``; the last
        # prefix character becomes the next input.
        if prefix_words:
            for word in prefix_words:
                _, hidden = model(inp, hidden)
                inp = to_input(word2ix[word])
        for i in range(Config.max_gen_len):
            output, hidden = model(inp, hidden)
            if i < start_words_len:
                # Still inside the given opening words: feed the next one and
                # discard the model's prediction.
                inp = to_input(word2ix[results[i]])
                continue
            # Past the opening words: greedy decode and feed the prediction back.
            top_index = output[0].topk(1).indices[0].item()
            w = ix2word[top_index]
            if w == '<EOP>':
                break
            results.append(w)
            inp = to_input(top_index)
    return results
# Generate an acrostic poem.
def gen_acrostic(model, start_words, ix2word, word2ix, prefix_words=None):
    """Generate an acrostic poem whose lines start with the characters of ``start_words``.

    Decoding is greedy. Whenever the previous token ends a line (a comma, full
    stop, exclamation or question mark, or the initial ``<START>``), the model's
    prediction is replaced by the next character of ``start_words``. Generation
    stops when a line ends after all head characters have been used, or after
    ``Config.max_gen_len`` model steps.

    Args:
        model: called as ``model(input, hidden)`` returning ``(output, hidden)``;
            ``output[0]`` is read as the scores for the next token.
        start_words: head characters, one per line.
        ix2word: mapping from token index to token string.
        word2ix: mapping from token string to token index.
        prefix_words: optional style-priming characters; fed to the model but
            not included in the result.

    Returns:
        List of token strings forming the poem.

    Raises:
        KeyError: if a character of ``start_words`` or ``prefix_words`` is not
            in ``word2ix``.
    """
    # Tokens after which the next line must begin with the next head character.
    # Includes the full-width comma and full stop, matching how callers split
    # acrostic output into lines.
    line_breaks = frozenset({'，', '。', '！', '？', '?', '<START>'})
    device = t.device('cuda') if Config.use_gpu else t.device('cpu')

    def to_input(ix):
        # A (1, 1) LongTensor holding one token index, on the model's device.
        return t.tensor([[ix]], dtype=t.long, device=device)

    result = []
    start_words_len = len(start_words)
    index = 0  # number of head characters already placed
    pre_word = '<START>'
    hidden = None
    # Inference only: avoid keeping the autograd graph alive through ``hidden``.
    with t.no_grad():
        inp = to_input(word2ix['<START>'])
        # Optional style prefix: only builds up ``hidden``.
        if prefix_words:
            for word in prefix_words:
                _, hidden = model(inp, hidden)
                inp = to_input(word2ix[word])
        for _ in range(Config.max_gen_len):
            output, hidden = model(inp, hidden)
            top_index = output[0].topk(1).indices[0].item()
            w = ix2word[top_index]
            # NOTE(review): a predicted '<EOP>' mid-poem is appended as-is;
            # confirm whether it should be filtered.
            if pre_word in line_breaks:
                # The previous token ended a line: start the next line with
                # the next head character, or stop if none are left.
                if index == start_words_len:
                    break
                w = start_words[index]
                index += 1
                inp = to_input(word2ix[w])
            else:
                inp = to_input(top_index)
            result.append(w)
            pre_word = w
    return result