Japanese GPT2 Lyric Model

Model description

This is a GPT-2 model for generating Japanese lyrics.

You can try it on my website https://lyric.fab.moe/

How to use

import torch
from transformers import T5Tokenizer, GPT2LMHeadModel

# Run on GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = T5Tokenizer.from_pretrained("skytnt/gpt2-japanese-lyric-small")
model = GPT2LMHeadModel.from_pretrained("skytnt/gpt2-japanese-lyric-small")
model.to(device)
model.eval()


def gen_lyric(prompt_text: str):
    # The prompt starts with "<s>", and real line breaks are encoded as a
    # literal backslash-n followed by a space
    prompt_text = "<s>" + prompt_text.replace("\n", "\\n ")
    prompt_tokens = tokenizer.tokenize(prompt_text)
    prompt_token_ids = tokenizer.convert_tokens_to_ids(prompt_tokens)
    prompt_tensor = torch.LongTensor(prompt_token_ids).to(device)
    prompt_tensor = prompt_tensor.view(1, -1)  # add a batch dimension
    # model forward: nucleus (top-p) plus top-k sampling
    output_sequences = model.generate(
        input_ids=prompt_tensor,
        max_length=512,
        top_p=0.95,
        top_k=40,
        temperature=1.0,
        do_sample=True,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        num_return_sequences=1
    )

    # convert model outputs to readable text
    generated_sequence = output_sequences.tolist()[0]
    generated_tokens = tokenizer.convert_ids_to_tokens(generated_sequence)
    generated_text = tokenizer.convert_tokens_to_string(generated_tokens)
    # literal "\n" markers become real line breaks (each line trimmed)
    generated_text = "\n".join(s.strip() for s in generated_text.split("\\n"))
    # remaining half-width spaces become full-width spaces (U+3000)
    generated_text = generated_text.replace(" ", "\u3000")
    # drop the start marker and mark where the model emitted end-of-sequence
    generated_text = generated_text.replace("<s>", "").replace("</s>", "\n\n---end---")
    return generated_text


# Prompt: "桜が咲く" ("the cherry blossoms bloom")
print(gen_lyric("桜が咲く"))
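The prompt and output text conventions that gen_lyric relies on can be pulled out into two plain string helpers, which makes them testable without downloading the model. This is only a sketch: encode_prompt and decode_output are hypothetical names introduced for illustration, not part of transformers or this model, and they simply mirror the string handling in gen_lyric.

```python
def encode_prompt(text: str) -> str:
    # Hypothetical helper: prepend the "<s>" marker and turn real line
    # breaks into a literal backslash-n followed by a space, as gen_lyric does.
    return "<s>" + text.replace("\n", "\\n ")


def decode_output(text: str) -> str:
    # Hypothetical helper: split on the literal backslash-n marker, trim each
    # line, and rejoin with real line breaks.
    lines = [s.strip() for s in text.split("\\n")]
    out = "\n".join(lines)
    # Remaining half-width spaces become full-width spaces (U+3000).
    out = out.replace(" ", "\u3000")
    # Remove the start marker and make the end-of-sequence marker visible.
    return out.replace("<s>", "").replace("</s>", "\n\n---end---")


if __name__ == "__main__":
    prompt = "桜が咲く\n春の風"
    print(encode_prompt(prompt))
    print(decode_output(encode_prompt(prompt) + "</s>"))
```

Because decode_output undoes encode_prompt, a multi-line prompt passed through both comes back unchanged, which is a quick way to check the newline convention before spending time on generation.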

Training data

The training data consists of 143,587 Japanese lyrics collected from uta-net using lyric_download.

Model size

123M parameters, stored as Safetensors (tensor types: F32, U8)