---
tags:
  - autotrain
  - text-generation
---

# Model Trained Using AutoTrain

This model was trained using AutoTrain. For more information, please visit the [AutoTrain documentation](https://hf.co/docs/autotrain).

# Usage
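
The snippet below loads the Llama-2-7B base model, applies the LoRA adapter from this repo with PEFT, and generates a rewritten, self-contained version of the final question in the prompt.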

```python
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import re

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the base model, then apply the LoRA adapter from this repo on top of it.
config = PeftConfig.from_pretrained("Ashishkr/llama2-qrecc")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = PeftModel.from_pretrained(model, "Ashishkr/llama2-qrecc").to(device)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Llama-2's tokenizer ships without a pad token; reuse EOS so generate() can pad.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
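# Note (an assumption, not from the original card): on GPUs with limited memory
# the base model can instead be loaded in half precision, e.g.
# AutoModelForCausalLM.from_pretrained(
#     "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16
# )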

def response_generate(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.7,
):
    """Generate a completion for `prompt` and return the decoded text."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inputs = tokenizer(
        [prompt],
        return_tensors="pt",
        return_token_type_ids=False,
    ).to(device)

    # Autocast to bfloat16 on whichever device is active, rather than
    # hard-coding "cuda" (which breaks on CPU-only machines).
    with torch.autocast(device.type, dtype=torch.bfloat16):
        response = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,  # temperature only takes effect when sampling
            return_dict_in_generate=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    decoded_output = tokenizer.decode(
        response["sequences"][0],
        skip_special_tokens=True,
    )

    return decoded_output

prompt = """>>CONTEXT<<I heard John Marks was the first christian missionary in Ireland. What was the capital then??>>REWRITE<< """

response = response_generate(
    model,
    tokenizer,
    prompt,
    max_new_tokens=20,
    temperature=0.1,
)

def extract_between_tags(input_string):
    # Pull out the text between >>REWRITE<< and the closing </REWRITE>
    # tag that the model emits; return an empty string if no match.
    pattern = r'>>REWRITE<<(.*?)</REWRITE>'
    match = re.search(pattern, input_string)
    return match.group(1) if match else ''

print(extract_between_tags(response))
```
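
The card only shows a single-turn prompt. For multi-turn conversations, a minimal sketch is below, assuming earlier turns are simply concatenated inside the `>>CONTEXT<<` section (`build_prompt` is a hypothetical helper, and the exact turn separator used in training is an assumption):

```python
# Hypothetical helper, not from the model card: join prior turns and the
# current question into the >>CONTEXT<< ... >>REWRITE<< format shown above.
def build_prompt(history, question):
    context = " ".join(history + [question])
    return f">>CONTEXT<<{context}>>REWRITE<< "

prompt = build_prompt(
    ["I heard John Marks was the first christian missionary in Ireland."],
    "What was the capital then?",
)
response = response_generate(model, tokenizer, prompt, max_new_tokens=20, temperature=0.1)
print(extract_between_tags(response))
```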