|
|
|
import random |
|
import jieba |
|
import collections |
|
|
|
|
|
class NGramModel:
    """N-gram language model with additive (Lidstone) smoothing.

    Training pads every sentence with ``n - 1`` ``"<s>"`` start markers and a
    single ``"</s>"`` end marker; ``"</s>"`` doubles as the stop symbol during
    generation.
    """

    def __init__(self, n, alpha):
        self.n = n          # order of the model (length of each n-gram)
        self.alpha = alpha  # additive smoothing constant (0 disables smoothing)
        # n-gram tuple -> count, and (n-1)-gram context tuple -> count.
        self.ngrams = collections.defaultdict(int)
        self.contexts = collections.defaultdict(int)
        # Every token ever seen, including the "<s>"/"</s>" markers.
        self.vocabulary = set()

    def train(self, corpus):
        """Accumulate n-gram and context counts from an iterable of token lists.

        May be called repeatedly; counts accumulate across calls.
        """
        for sentence in corpus:
            # Pad so the first real word has a full-length context and the
            # model learns where sentences end.
            sentence = ["<s>"] * (self.n - 1) + sentence + ["</s>"]
            for i in range(len(sentence) - self.n + 1):
                ngram = tuple(sentence[i:i + self.n])
                self.ngrams[ngram] += 1
                self.contexts[ngram[:-1]] += 1
                self.vocabulary.update(ngram)

    def predict(self, context):
        """Return ``{word: P(word | context)}`` with additive smoothing.

        Uses ``.get()`` rather than defaultdict indexing so that probing
        unseen n-grams does not insert zero-count entries into the count
        tables (indexing grew them by O(|V|) per call).

        Raises ZeroDivisionError if ``alpha == 0`` and *context* was never
        seen in training (denominator is zero), matching prior behavior.
        """
        context = tuple(context)
        denominator = self.contexts.get(context, 0) + self.alpha * len(self.vocabulary)
        probabilities = {}
        for word in self.vocabulary:
            count = self.ngrams.get(context + (word,), 0)
            probabilities[word] = (count + self.alpha) / denominator
        return probabilities

    def sample(self, probabilities):
        """Draw one word from a ``{word: probability}`` distribution.

        Returns the last word as a fallback when floating-point rounding
        leaves the cumulative sum just below the drawn threshold (previously
        the loop could fall through and return None).
        """
        total = sum(probabilities.values())
        threshold = random.uniform(0, total)
        cumulative_probability = 0.0
        word = None
        for word, probability in probabilities.items():
            cumulative_probability += probability
            if cumulative_probability >= threshold:
                return word
        return word  # rounding fallback; None only if the distribution is empty

    def generate(self, context):
        """Generate a continuation of *context* and return the new words only.

        Seeds shorter than ``n - 1`` tokens are left-padded with ``"<s>"`` so
        lookups match the contexts produced by train() (previously a short
        seed queried contexts of the wrong length — pure smoothing noise —
        and the fixed ``[n-1:-1]`` slice chopped off generated words).
        The seed tokens and the trailing ``"</s>"`` are stripped from the
        returned list.
        """
        context = tuple(context)
        if len(context) < self.n - 1:
            context = ("<s>",) * (self.n - 1 - len(context)) + context
        seed_length = len(context)
        sentence = list(context)
        while True:
            word = self.sample(self.predict(context))
            sentence.append(word)
            if word == "</s>":
                break
            # Slide the context window forward by one generated word.
            context = context[1:] + (word,)
        return sentence[seed_length:-1]
|
|
|
|
|
def _load_corpus(path):
    """Read one sentence per non-empty line of *path* and tokenize with jieba."""
    corpus = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                corpus.append(list(jieba.cut(line)))
    return corpus


def main():
    """Train a trigram model on corpus.txt and print two generated samples."""
    corpus = _load_corpus("corpus.txt")
    print("语料库中的句子数:", len(corpus))
    print(corpus)

    model = NGramModel(3, 0.01)
    model.train(corpus)
    print("词汇表中的词数:", len(model.vocabulary))
    print("n-1-gram的计数:", model.contexts.items())

    sentence = model.generate(("我", "爱"))
    print("".join(sentence))
    sentence = model.generate(("我",))
    print("".join(sentence))


# Guard so importing this module no longer triggers file I/O and training.
if __name__ == "__main__":
    main()
|
|