|
--- |
|
language: ru |
|
license: cc-by-nc-4.0 |
|
tags: |
|
- paraphrasing |
|
- seq2seq |
|
datasets: |
|
- inkoziev/paraphrases |
|
--- |
|
|
|
## Поэтический перефразировщик |
|
|
|
Это генеративная модель на основе ```sberbank-ai/rugpt3large_based_on_gpt2```, дообученной |
|
на датасете перефразировок [inkoziev/paraphrases](https://huggingface.co/datasets/inkoziev/paraphrases). |
|
Она разработана для использования в проекте [генеративной поэзии](https://github.com/Koziev/verslibre). |
|
Код для тренировки и использования перефразировщика доступен в репозитрии [https://github.com/Koziev/paraphraser](https://github.com/Koziev/paraphraser). |
|
|
|
|
|
### Особенности перефразировки |
|
|
|
Обращаю внимание, что модель **не предназначена** для использования там, где требуется |
|
особо аккуратная работа с именованными сущностями. Так как в стихах не возникает особых проблем (более того, |
|
в некоторых сценариях использования это даже желательно), если перефразировки теряют или добавляют некоторую семантику в исходный текст, то обучающий датасет |
|
и модель на его основе может путать дни недели, имена, добавлять что-то от себя, быть метафоричной или иносказательной. |
|
|
|
### Методика файнтюна |
|
|
|
В обучающем датасете есть негативные примеры перефразировок, и я использую их вместе с правильными примерами в ходе файнтюна, |
|
подавая на классификационную голову в [GPT2DoubleHeadsModel](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2DoubleHeadsModel). |
|
Код, выполняющий файнтюн, доступен [тут](https://github.com/Koziev/paraphraser/blob/main/train_paraphraser_with_gpt2doublehead.py). |
|
|
|
Такой подход к файнтюну оказался лучше, чем два других подхода: |
|
|
|
1) дефолтный способ файнтюна, когда GPT дообучается просто на текстах, состоящих из исходного текста и перефразировки, |
|
разделенных специальным токеном. В этом подходе модель обучается также на токенах затравки, что может быть нежелательным. |
|
2) вариация первого способа, в котором токены затравки (исходного текста) исключаются из обратного распространения с помощью |
|
задания labels=-100 ([код](https://github.com/Koziev/paraphraser/blob/main/finetune_paraphraser_with_prompt_masking.py)). |
|
|
|
В качестве метрики для сравнения подходов и для подбора числа неверных вариантов перефразировки в GPT2DoubleHeadsModel |
|
использована комбинация из: |
|
1) близость векторов эмбеддингов исходного текста и сгенерированной перефразировки. Векторы получаются с помощью |
|
модели ```sberbank-ai/sbert_large_mt_nlu_ru```. Я не стал использовать [модель-критик](https://huggingface.co/inkoziev/sbert_synonymy), |
|
поскольку она обучалась на таком же датасете. |
|
2) дисконтируем результаты п.1 символьной близостью (3-граммы) по коэффициенту Жаккара. Это штрафует перестановочные |
|
перефразировки, воспроизведение исходного текста и небольшие переписывания. |
|
|
|
### Формат входных данных |
|
|
|
На вход модели подается исходный текст с добавлением токенов ```<s>``` в начале и ```<sep>``` в конце, например: |
|
|
|
``` |
|
input_text = '<s>Мороз и солнце, день чудесный<sep>' |
|
``` |
|
|
|
Результат генерации будет содержать текст с токеном ```</s>``` - это конец последовательности. |
|
|
|
### Пример использования |
|
|
|
Следующий код позволяет ввести в консоли короткое предложение |
|
и видеть результат ее перефразировки моделью: |
|
``` |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
model_name = "inkoziev/paraphraser" |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
model = AutoModelForCausalLM.from_pretrained(model_name) |
|
model.to(device) |
|
model.eval() |
|
|
|
while True: |
|
seed = input(':> ').strip() |
|
encoded_prompt = tokenizer.encode("<s>" + seed + "<sep>", add_special_tokens=False, return_tensors="pt").to(device) |
|
output_sequences = model.generate(input_ids=encoded_prompt, |
|
max_length=100, |
|
typical_p=0.85, |
|
top_k=0, |
|
top_p=1.0, |
|
do_sample=True, |
|
num_return_sequences=10, |
|
pad_token_id=tokenizer.pad_token_id) |
|
|
|
for o in output_sequences: |
|
text = tokenizer.decode(o.tolist(), clean_up_tokenization_spaces=True) |
|
text = text[text.index('<sep>') + 5:] |
|
text = text[: text.find('</s>')] |
|
print(text) |
|
``` |
|
|