---
license: mit
language:
- ru
library_name: transformers
tags:
- gpt2
- conversational
- not-for-all-audiences
base_model: tinkoff-ai/ruDialoGPT-medium
inference: false
widget:
- text: "@@ПЕРВЫЙ@@Привет, как дела?@@ВТОРОЙ@@"
  example_title: "Greet"
- text: "@@ПЕРВЫЙ@@Ты нормальный вообще?@@ВТОРОЙ@@"
  example_title: "Confront"
---

This is a toxic conversational model based on [tinkoff-ai/ruDialoGPT-medium](https://huggingface.co/tinkoff-ai/ruDialoGPT-medium).
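
Like the base model, it consumes a dialogue as a single string in which turns are delimited by the `@@ПЕРВЫЙ@@` (first speaker) and `@@ВТОРОЙ@@` (second speaker) tokens, and the prompt ends with the token of the speaker whose reply should be generated (see the widget examples above). A minimal helper for building such a prompt — `build_context` is a hypothetical name for illustration, not part of any API:

```python
def build_context(turns):
    """Join alternating dialogue turns with speaker tokens and leave
    the prompt open for the next speaker's reply."""
    markers = ["@@ПЕРВЫЙ@@", "@@ВТОРОЙ@@"]
    # Prefix each turn with the marker of the speaker who said it.
    parts = [markers[i % 2] + turn for i, turn in enumerate(turns)]
    # Close with the marker of whoever speaks next.
    return "".join(parts) + markers[len(turns) % 2]

print(build_context(["Привет, как дела?"]))
# @@ПЕРВЫЙ@@Привет, как дела?@@ВТОРОЙ@@
```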

## Model training

We built a custom dataset from [raw imageboard dialogue data](https://github.com/Koziev/NLP_Datasets/tree/master/Conversations/Data).
The data processing notebook is available [here](https://github.com/RakePants/nerdless/blob/main/notebooks/dataset.ipynb).

The model was fine-tuned on a 350,000-sample chunk of the dataset with the following parameters:
|
```python
learning_rate=4e-7,
num_train_epochs=1,
per_device_train_batch_size=24,
per_device_eval_batch_size=24,
warmup_steps=100,
gradient_accumulation_steps=16,
fp16=True
```
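
With gradient accumulation, the effective batch size per optimizer step is larger than the per-device batch size. A quick sanity check (the single-device count below is an assumption for illustration):

```python
# Effective batch size per optimizer step =
#   per-device batch size * accumulation steps * number of devices.
per_device_train_batch_size = 24
gradient_accumulation_steps = 16
num_devices = 1  # assumption: a single GPU

effective_batch_size = (per_device_train_batch_size
                        * gradient_accumulation_steps
                        * num_devices)
print(effective_batch_size)  # 384
```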

The fine-tuning notebook is available [here](https://github.com/RakePants/nerdless/blob/main/notebooks/finetuning.ipynb).

## Inference

You can use BetterTransformer (from 🤗 Optimum) for faster inference.

The model can be run as follows:

```python
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "rakepants/ruDialoGPT-medium-finetuned-toxic"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model_hf = AutoModelForCausalLM.from_pretrained(checkpoint)
model = BetterTransformer.transform(model_hf, keep_original_model=False)

# token id 50257 - @@ПЕРВЫЙ@@ (first speaker)
# token id 50258 - @@ВТОРОЙ@@ (second speaker)

context = "@@ПЕРВЫЙ@@Привет, как дела?@@ВТОРОЙ@@"
inputs = tokenizer(context, return_tensors="pt")

generated_token_ids = model.generate(
    **inputs,
    top_k=10,
    top_p=0.95,
    num_beams=3,
    num_return_sequences=1,
    do_sample=True,
    no_repeat_ngram_size=2,
    temperature=0.7,
    repetition_penalty=1.2,
    length_penalty=1.0,
    early_stopping=True,
    max_new_tokens=48,
    eos_token_id=50257,  # stop once the model opens a new @@ПЕРВЫЙ@@ turn
    pad_token_id=0
)

# Each decoded string contains the original context followed by the generated response.
context_with_response = [tokenizer.decode(sample_token_ids) for sample_token_ids in generated_token_ids]
```
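
Each decoded string includes the prompt, the model's reply, and the speaker tokens. A minimal sketch of extracting just the reply with plain string handling — the `decoded` value below is a made-up example, not a real model generation:

```python
# Hypothetical decoded output, for illustration only:
decoded = "@@ПЕРВЫЙ@@Привет, как дела?@@ВТОРОЙ@@норм, сам как?@@ПЕРВЫЙ@@"

# The reply is everything after the last @@ВТОРОЙ@@ marker,
# cut off at the next speaker token if the model emitted one.
reply = decoded.split("@@ВТОРОЙ@@")[-1].split("@@ПЕРВЫЙ@@")[0].strip()
print(reply)  # норм, сам как?
```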