---
license: mit
language:
  - ru
library_name: transformers
tags:
  - gpt2
  - conversational
  - not-for-all-audiences
base_model: tinkoff-ai/ruDialoGPT-medium
inference: false
widget:
  - text: '@@ПЕРВЫЙ@@Привет, как дела?@@ВТОРОЙ@@'
    example_title: Greet
  - text: '@@ПЕРВЫЙ@@Ты нормальный вообще?@@ВТОРОЙ@@'
    example_title: Confront
---

This is a toxic conversational model based on `tinkoff-ai/ruDialoGPT-medium`.

## Model training

We've created a custom dataset out of raw imageboard dialogue data.
The data processing notebook is available here.

The model was fine-tuned on a 350,000-sample chunk of the dataset with the following parameters:

```python
learning_rate=4e-7,
num_train_epochs=1,
per_device_train_batch_size=24,
per_device_eval_batch_size=24,
warmup_steps=100,
gradient_accumulation_steps=16,
fp16=True
```

The fine-tuning notebook is available here.
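
For illustration only (the actual notebook may differ), these values correspond to a `transformers.TrainingArguments` setup roughly like the sketch below; the output directory, the tokenized `train_dataset`, and the data collator are assumptions, not part of the original card.

```python
# A minimal sketch, assuming a tokenized `train_dataset` of dialogue samples;
# the real fine-tuning notebook may be organized differently.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

base_checkpoint = "tinkoff-ai/ruDialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
model = AutoModelForCausalLM.from_pretrained(base_checkpoint)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2-style tokenizers often lack a pad token

training_args = TrainingArguments(
    output_dir="ruDialoGPT-medium-finetuned-toxic",  # hypothetical output path
    learning_rate=4e-7,
    num_train_epochs=1,
    per_device_train_batch_size=24,
    per_device_eval_batch_size=24,
    warmup_steps=100,
    gradient_accumulation_steps=16,
    fp16=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # assumed: tokenized dialogue chunks
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```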

## Inference

You can use Optimum's BetterTransformer API for faster inference.

The model can be run as follows:

```python
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "rakepants/ruDialoGPT-medium-finetuned-toxic"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model_hf = AutoModelForCausalLM.from_pretrained(checkpoint)
model = BetterTransformer.transform(model_hf, keep_original_model=False)

# token id 50257 - @@ПЕРВЫЙ@@ (first speaker)
# token id 50258 - @@ВТОРОЙ@@ (second speaker)

context = "@@ПЕРВЫЙ@@Привет, как дела?@@ВТОРОЙ@@"  # "Hi, how are you?"
inputs = tokenizer(context, return_tensors='pt')

generated_token_ids = model.generate(
    **inputs,
    top_k=10,
    top_p=0.95,
    num_beams=3,
    num_return_sequences=1,
    do_sample=True,
    no_repeat_ngram_size=2,
    temperature=0.7,
    repetition_penalty=1.2,
    length_penalty=1.0,
    early_stopping=True,
    max_new_tokens=48,
    eos_token_id=50257,  # stop when the model emits @@ПЕРВЫЙ@@
    pad_token_id=0
)

# Each decoded string contains the original context followed by the generated reply.
context_with_response = [tokenizer.decode(sample_token_ids) for sample_token_ids in generated_token_ids]
```
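
The decoded string contains the prompt followed by the model's reply; generation stops once the model emits `@@ПЕРВЫЙ@@` (token id 50257). As a small post-processing sketch (not from the original card), the reply alone can be extracted by splitting on the speaker tokens:

```python
# Sketch: keep only the newly generated turn. The prompt ends with @@ВТОРОЙ@@,
# so everything after its last occurrence is the model's reply; strip a trailing
# @@ПЕРВЫЙ@@ if generation stopped on it.
reply = (
    context_with_response[0]
    .split("@@ВТОРОЙ@@")[-1]
    .replace("@@ПЕРВЫЙ@@", "")
    .strip()
)
print(reply)
```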