113 lines
3.8 KiB
Python
113 lines
3.8 KiB
Python
import re
|
||
import numpy as np
|
||
import torch as t
|
||
from model import PoetryModel
|
||
from config import Config
|
||
|
||
|
||
def gen(start_words, action):
    """Load the trained poetry model and generate a poem, returned as a list of fragments.

    Args:
        start_words: for ``action == 1`` the opening characters of the poem;
            otherwise the characters that begin the lines of an acrostic poem.
        action: ``1`` continues a poem from ``start_words`` (via ``generate``);
            any other value produces an acrostic poem (via ``gen_acrostic``).

    Returns:
        list of str: the generated text split on ``'。'`` (continuation) or on
        ``','`` and ``'。'`` (acrostic). Fragments are returned as produced by
        the split, so a trailing empty string appears when the text ends with
        a separator.
    """
    print("正在初始化......")
    # NOTE(review): allow_pickle=True is required because the vocab dicts are
    # read back with .item() (0-d object arrays). Only load trusted files.
    # The context manager closes the underlying .npz file handle. Only the
    # vocab entries are read; the (unused) training data array is never loaded.
    with np.load("tang.npz", allow_pickle=True) as datas:
        ix2word = datas['ix2word'].item()
        word2ix = datas['word2ix'].item()

    model = PoetryModel(len(ix2word), Config.embedding_dim, Config.hidden_dim)
    # Weights are mapped to the CPU first, then moved to the GPU if enabled.
    model.load_state_dict(t.load(Config.model_path, 'cpu'))
    if Config.use_gpu:
        model.to(t.device('cuda'))
    # Inference only: put modules such as dropout (if PoetryModel has any)
    # into evaluation behavior.
    model.eval()
    print("初始化完成!\n")

    # No gradients are needed while sampling; avoid building an autograd graph.
    with t.no_grad():
        if action == 1:
            gen_poetry = ''.join(generate(model, start_words, ix2word, word2ix))
            arr = gen_poetry.split('。')
        else:
            gen_poetry = ''.join(gen_acrostic(model, start_words, ix2word, word2ix))
            # Acrostic lines may end with ',' as well as '。', so split on both.
            # NOTE(review): confirm ',' matches the comma used in the vocabulary
            # (ASCII vs full-width).
            arr = re.split(',|。', gen_poetry)

    print(arr)
    return arr


def generate(model, start_words, ix2word, word2ix, prefix_words=None,
             *, max_gen_len=None, use_gpu=None):
    """Continue a poem from ``start_words`` with greedy decoding.

    The characters of ``start_words`` are fed to ``model`` one at a time, and
    the model's predictions are discarded, so that the hidden state reflects the
    given opening. After that, the highest-scoring token at each step is
    appended and fed back as the next input. Generation stops when ``'<EOP>'``
    is predicted or after ``max_gen_len`` model steps.

    Args:
        model: callable ``model(input, hidden) -> (output, hidden)``. ``input``
            is a ``(1, 1)`` long tensor holding one vocabulary index, and
            ``output[0]`` is used as the score vector over the vocabulary.
        start_words: opening characters, kept verbatim at the start of the result.
        ix2word: mapping from vocabulary index to token.
        word2ix: mapping from token to vocabulary index. It must contain
            ``'<START>'`` and every character of ``start_words`` and
            ``prefix_words``.
        prefix_words: optional text fed through the model first to prime the
            hidden state ("style" prefix). It is not included in the result.
        max_gen_len: number of model steps in the main loop; the forced
            ``start_words`` steps count toward it. Defaults to
            ``Config.max_gen_len``.
        use_gpu: move the input tensors to CUDA. Defaults to ``Config.use_gpu``.

    Returns:
        list of str: ``start_words`` followed by the generated tokens, without
        the terminating ``'<EOP>'``.
    """
    if max_gen_len is None:
        max_gen_len = Config.max_gen_len
    if use_gpu is None:
        use_gpu = Config.use_gpu

    results = list(start_words)
    start_words_len = len(results)

    # The very first input is always the <START> token.
    token = t.tensor([[word2ix['<START>']]], dtype=t.long)
    if use_gpu:
        token = token.cuda()
    device = token.device

    def as_input(ix):
        # Wrap a vocabulary index as a (1, 1) long tensor on the same device.
        return t.tensor([[ix]], dtype=t.long, device=device)

    hidden = None
    # Pure inference: without no_grad the hidden state would carry an autograd
    # graph that grows with every step.
    with t.no_grad():
        # Prime the hidden state with the style prefix. The first input is
        # <START>, then each prefix character. The last prefix character is
        # left in `token` and is consumed by the first step of the main loop.
        if prefix_words:
            for word in prefix_words:
                _, hidden = model(token, hidden)
                token = as_input(word2ix[word])

        for i in range(max_gen_len):
            output, hidden = model(token, hidden)
            if i < start_words_len:
                # Still inside the given opening: force its next character and
                # ignore the prediction. This only advances the hidden state.
                token = as_input(word2ix[results[i]])
                continue
            top_index = output[0].topk(1)[1][0].item()
            word = ix2word[top_index]
            if word == '<EOP>':
                break
            results.append(word)
            token = as_input(top_index)

    return results


def gen_acrostic(model, start_words, ix2word, word2ix, prefix_words=None,
                 *, max_gen_len=None, use_gpu=None):
    """Generate an acrostic poem whose lines begin with the characters of ``start_words``.

    At the start, and whenever the previously emitted token is a line-ending
    mark, the next character of ``start_words`` is forced in place of the
    model's prediction. Otherwise the model's top prediction is emitted and fed
    back. Generation stops once every head character has been used and another
    line end is reached, or after ``max_gen_len`` model steps.

    Args:
        model: callable ``model(input, hidden) -> (output, hidden)``. ``input``
            is a ``(1, 1)`` long tensor holding one vocabulary index, and
            ``output[0]`` is used as the score vector over the vocabulary.
        start_words: head characters, one per generated line.
        ix2word: mapping from vocabulary index to token.
        word2ix: mapping from token to vocabulary index. It must contain
            ``'<START>'`` and every character of ``start_words`` and
            ``prefix_words``.
        prefix_words: optional text fed through the model first to prime the
            hidden state ("style" prefix). It is not included in the result.
        max_gen_len: maximum number of model steps. Defaults to
            ``Config.max_gen_len``.
        use_gpu: move the input tensors to CUDA. Defaults to ``Config.use_gpu``.

    Returns:
        list of str: the emitted tokens, including punctuation.
        NOTE(review): '<EOP>' is not handled specially here. If the model
        predicts it mid-line, it is emitted verbatim and fed back.
    """
    if max_gen_len is None:
        max_gen_len = Config.max_gen_len
    if use_gpu is None:
        use_gpu = Config.use_gpu

    # Tokens after which the next head character is forced.
    # NOTE(review): ',', '?', '!' here are ASCII. If the vocabulary uses
    # full-width punctuation, only '。' and '<START>' can ever match. Confirm
    # against ix2word.
    line_ends = {'。', ',', '?', '!', '<START>'}

    result = []
    start_words_len = len(start_words)

    token = t.tensor([[word2ix['<START>']]], dtype=t.long)
    if use_gpu:
        token = token.cuda()
    device = token.device

    def as_input(ix):
        # Wrap a vocabulary index as a (1, 1) long tensor on the same device.
        return t.tensor([[ix]], dtype=t.long, device=device)

    index = 0  # number of head characters placed so far
    pre_word = '<START>'
    hidden = None

    # Pure inference: keep autograd from recording the whole generation.
    with t.no_grad():
        # Prime the hidden state with the style prefix, if given.
        if prefix_words:
            for word in prefix_words:
                _, hidden = model(token, hidden)
                token = as_input(word2ix[word])

        for _ in range(max_gen_len):
            output, hidden = model(token, hidden)
            top_index = output[0].topk(1)[1][0].item()
            word = ix2word[top_index]

            if pre_word in line_ends:
                # A new line starts here: stop when all heads are used,
                # otherwise replace the prediction with the next head character.
                if index == start_words_len:
                    break
                word = start_words[index]
                index += 1
                token = as_input(word2ix[word])
            else:
                token = as_input(top_index)

            result.append(word)
            pre_word = word

    return result