---
license: mit
language:
- ru
library_name: transformers
tags:
- gpt2
- conversational
- not-for-all-audiences
base_model: tinkoff-ai/ruDialoGPT-medium
inference: false
widget:
- text: "@@ПЕРВЫЙ@@Привет, как дела?@@ВТОРОЙ@@"
example_title: "Greet"
- text: "@@ПЕРВЫЙ@@Ты нормальный вообще?@@ВТОРОЙ@@"
example_title: "Confront"
---
This is a toxic conversational model based on [tinkoff-ai/ruDialoGPT-medium](https://huggingface.co/tinkoff-ai/ruDialoGPT-medium).
## Model training
We built a custom dataset from [raw imageboard dialogue data](https://github.com/Koziev/NLP_Datasets/tree/master/Conversations/Data).
The data processing notebook is available [here](https://github.com/RakePants/nerdless/blob/main/notebooks/dataset.ipynb).
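The exact processing steps are in the notebook above; the sketch below only illustrates the ruDialoGPT dialogue format that the samples are converted into (the function name and the toy dialogue are illustrative, not taken from the actual dataset):
```
# Illustrative only: the real processing pipeline lives in the dataset notebook.
# ruDialoGPT marks alternating speakers with @@ПЕРВЫЙ@@ / @@ВТОРОЙ@@ tokens.
def to_dialog_format(turns):
    """Join dialogue turns with alternating speaker markers."""
    speakers = ["@@ПЕРВЫЙ@@", "@@ВТОРОЙ@@"]
    return "".join(f"{speakers[i % 2]}{turn.strip()}" for i, turn in enumerate(turns))

print(to_dialog_format(["Привет, как дела?", "Нормально, а у тебя?"]))
# @@ПЕРВЫЙ@@Привет, как дела?@@ВТОРОЙ@@Нормально, а у тебя?
```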
The model was finetuned on a 350,000-sample chunk of the dataset with the following parameters:
```
learning_rate=4e-7,
num_train_epochs=1,
per_device_train_batch_size=24,
per_device_eval_batch_size=24,
warmup_steps=100,
gradient_accumulation_steps=16,
fp16=True
```
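For reference, here is a minimal sketch of how these parameters could be wired into the Hugging Face `Trainer` API. The `output_dir`, `train_ds`, and `eval_ds` names are placeholders (assumptions, not from the actual notebook):
```
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)

checkpoint = "tinkoff-ai/ruDialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 tokenizers often ship without a pad token

args = TrainingArguments(
    output_dir="ruDialoGPT-medium-finetuned-toxic",  # placeholder path
    learning_rate=4e-7,
    num_train_epochs=1,
    per_device_train_batch_size=24,
    per_device_eval_batch_size=24,
    warmup_steps=100,
    gradient_accumulation_steps=16,
    fp16=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,  # tokenized dialogue samples (placeholder)
    eval_dataset=eval_ds,    # held-out samples (placeholder)
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```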
The finetuning notebook is available [here](https://github.com/RakePants/nerdless/blob/main/notebooks/finetuning.ipynb).
## Inference
You can use BetterTransformer (from 🤗 Optimum) for faster inference.
The model can be used for inference as follows:
```
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "rakepants/ruDialoGPT-medium-finetuned-toxic"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model_hf = AutoModelForCausalLM.from_pretrained(checkpoint)

# Convert to a BetterTransformer model for faster inference.
model = BetterTransformer.transform(model_hf, keep_original_model=False)

# Speaker tokens: id 50257 - @@ПЕРВЫЙ@@, id 50258 - @@ВТОРОЙ@@.
# The prompt ends with @@ВТОРОЙ@@ so the model generates the second speaker's reply.
text = "@@ПЕРВЫЙ@@Привет, как дела?@@ВТОРОЙ@@"
inputs = tokenizer(text, return_tensors='pt')

generated_token_ids = model.generate(
    **inputs,
    top_k=10,
    top_p=0.95,
    num_beams=3,
    num_return_sequences=1,
    do_sample=True,
    no_repeat_ngram_size=2,
    temperature=0.7,
    repetition_penalty=1.2,
    length_penalty=1.0,
    early_stopping=True,
    max_new_tokens=48,
    eos_token_id=50257,  # stop once the model emits @@ПЕРВЫЙ@@
    pad_token_id=0
)

# Each decoded string contains the original prompt followed by the generated reply.
context_with_response = [tokenizer.decode(sample_token_ids) for sample_token_ids in generated_token_ids]
```
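To show only the generated reply rather than the full context, you can split the decoded string on the speaker tokens. This post-processing step is a sketch, not part of the original pipeline:
```
# Sketch (assumption, not from the original card): keep only the model's reply,
# i.e. the text between the final @@ВТОРОЙ@@ and the closing @@ПЕРВЫЙ@@.
decoded = context_with_response[0]
reply = decoded.split("@@ВТОРОЙ@@")[-1].split("@@ПЕРВЫЙ@@")[0].strip()
print(reply)
```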