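In a traditional language model, such as an RNN, each word of a sentence $S = w_1, \ldots, w_n$ is conditioned only on the words before it: $P(S) = \prod_{i=1}^{n} P(w_i \mid w_1, \ldots, w_{i-1})$.
In a bidirectional language model, each word has a larger context that covers both sides: $P(w_i \mid w_1, \ldots, w_{i-1}, w_{i+1}, \ldots, w_n)$.
In this implementation, we simply adopt the following approximation: $P(S) \approx \prod_{i=1}^{n} P(w_i \mid w_1, \ldots, w_{i-1}, w_{i+1}, \ldots, w_n)$.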
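To make the approximation concrete: if the sentence score is taken as the product of per-token conditional probabilities, the log-probability is a sum of logs and the perplexity is the exponential of the average negative log-probability. The sketch below assumes the per-token probabilities have already been obtained from some model; the function names and the example numbers are hypothetical and used only for illustration.

```python
import math
from typing import Sequence


def sentence_log_prob(token_probs: Sequence[float]) -> float:
    """Log of the product of the per-token probabilities (a sum of logs)."""
    return sum(math.log(p) for p in token_probs)


def perplexity(token_probs: Sequence[float]) -> float:
    """Exponential of the average negative log-probability per token."""
    if not token_probs:
        raise ValueError("need at least one token probability")
    return math.exp(-sentence_log_prob(token_probs) / len(token_probs))


# Made-up probabilities for a three-token sentence (illustration only).
example_probs = [0.5, 0.25, 0.5]
print(math.exp(sentence_log_prob(example_probs)))  # approximate P(S)
print(perplexity(example_probs))                   # lower means more plausible
```

Working in log space avoids numerical underflow when many small probabilities are multiplied for a long sentence.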
export BERT_BASE_DIR=model/chinese_L-12_H-768_A-12
export INPUT_FILE=data/lm/poetry2.tsv
python run_lm_predict.py \
--input_file=$INPUT_FILE \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--max_seq_length=128 \
--output_dir=./tmp/lm_output/
$ cat /tmp/lm/output/test_result.json
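'''Use BERT as a language model to score sentences and check whether a sentence is plausible.
This is similar to BERT-MLM-based Chinese text correction, where each character is masked in turn and a loss is computed.'''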
from torch.multiprocessing import TimeoutError, Pool, set_start_method, Queue
import torch.multiprocessing as mp
import torch
import numpy as np
# from transformers import DistilBertTokenizer,DistilBertForMaskedLM
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM
import json, math
try:
    set_start_method('spawn')
except RuntimeError:
    pass
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model():
    ## Load the BERT model; the directory at this path contains the bert_config.json config file and the model.bin weights file
    # bert-base-uncased is the English model
    model = BertForMaskedLM.from_pretrained('bert-base-chinese').to(device)
    model.eval()
    ## Load the BERT tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    return tokenizer, model
tokenizer, model = load_model()
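'''
Use the loss as the sentence perplexity (PPL) score.
Limitations:
1. Scoring each word requires its own inference pass, which is computationally expensive and redundant. There is room for optimization.
2. The sentence probability used in this implementation is an approximation and is not rigorous.
'''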
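def get_score(sentence):
    tokenize_input = tokenizer.tokenize(sentence)
    # Move the input to the same device as the model
    tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)]).to(device)
    # Predict all tokens in a single pass; note that no token is masked here
    with torch.no_grad():
        predictions = model(tensor_input)  # model(masked_ids)
    # nn.CrossEntropyLoss(size_average=False)
    # Per the PyTorch documentation, size_average defaults to True, so the loss is averaged over each mini-batch. If size_average is set to False, the losses are summed instead. It is ignored when reduce=False.
    loss_fct = torch.nn.CrossEntropyLoss()
    # The loss is already averaged over tokens; its exponential is returned as the sentence's PPL score
    loss = loss_fct(predictions.squeeze(0), tensor_input.squeeze(0)).item()
    return math.exp(loss)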
print(get_score("杜甫是什么的诗词是有哪些"))
print(get_score("杜甫的诗词有哪些"))
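The first limitation noted in the script, one inference pass per word, can be eased by building all masked copies of the sentence up front and scoring them as a single batch. The following is a minimal sketch of that idea in plain Python, not the script's own method. It assumes a hypothetical `predict_fn` that takes a batch of token-id sequences and returns, for every sequence and every position, a probability distribution over the vocabulary. The names `build_masked_batch`, `masked_perplexity`, and `uniform_model` are illustrative and do not come from the original script or from any library.

```python
import math
from typing import Callable, List, Sequence

# Hypothetical model interface: batch of id sequences in,
# probs[row][position][vocab_id] out.
PredictFn = Callable[[List[List[int]]], List[List[Sequence[float]]]]


def build_masked_batch(token_ids: Sequence[int], mask_id: int) -> List[List[int]]:
    """Return n copies of token_ids, with position i masked in copy i."""
    batch = []
    for i in range(len(token_ids)):
        row = list(token_ids)
        row[i] = mask_id
        batch.append(row)
    return batch


def masked_perplexity(token_ids: Sequence[int], mask_id: int,
                      predict_fn: PredictFn) -> float:
    """Perplexity from the probability of each true token at its masked position."""
    if not token_ids:
        raise ValueError("need at least one token")
    probs = predict_fn(build_masked_batch(token_ids, mask_id))  # one batched call
    nll = 0.0
    for i, true_id in enumerate(token_ids):
        # Row i has position i masked, so read the prediction at [i][i].
        nll -= math.log(probs[i][i][true_id])
    return math.exp(nll / len(token_ids))


def uniform_model(batch: List[List[int]]) -> List[List[Sequence[float]]]:
    """Toy stand-in for a real model: uniform over a 4-token vocabulary."""
    return [[[0.25] * 4 for _ in row] for row in batch]


print(build_masked_batch([5, 6, 7], 0))
print(masked_perplexity([1, 2, 3], 0, uniform_model))
```

With a real BERT model, `predict_fn` would presumably wrap a softmax over the masked-LM logits, and `mask_id` would be the tokenizer's id for `[MASK]`; both are assumptions here rather than part of the script above.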