File size: 4,207 Bytes
917b125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import ast
import re

import tenacity
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel


class LLM:
    """Wrapper around a Hugging Face chat model that answers option-selection queries.

    Three loading/generation paths are selected by substring match on
    ``model_id``: ``"meta-llama"``, ``"InternVL"``, and a default path used
    for everything else (including the default Qwen model).
    """

    # Deletes "_" and "-" from every option name returned by ``answer``.
    _STRIP_SEPARATORS = str.maketrans("", "", "_-")

    # A fenced ```python ... ``` block. The separator after "python" may be
    # real whitespace/newlines or the two literal characters backslash + "n"
    # (in case the model echoes the escape sequence verbatim). The capture is
    # lazy so it stops at the first closing fence.
    _FENCED_LIST_RE = re.compile(r"`{3}python(?:\\n|\s)*(.*?)`{3}", re.DOTALL)

    # Fallback: the first "[...]" span anywhere in the response.
    _BRACKET_LIST_RE = re.compile(r"\[(.*?)\]", re.DOTALL)

    def __init__(self, model_id="Qwen/Qwen2.5-7B-Instruct",):
        """Load the model and tokenizer for ``model_id``.

        Args:
            model_id: Hugging Face model identifier passed to
                ``from_pretrained``.
        """
        self.model_id = model_id
        # Device that tokenized inputs are moved to before generation.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load the model and tokenizer based on the model_id
        if "meta-llama" in self.model_id:
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto"
            )

        elif "InternVL" in self.model_id:
            self.model = AutoModel.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                device_map="auto"
            ).eval()

            self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False)

        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype="auto",
                device_map="auto"
            )

            self.tokenizer = AutoTokenizer.from_pretrained(model_id)

    def _chat_generate(self, messages, max_new_tokens=512):
        """Run one chat-template generation and return the decoded completion.

        Args:
            messages: Chat messages accepted by the tokenizer's
                ``apply_chat_template``.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded text of the first (only) sequence, with special
            tokens skipped and the prompt tokens removed.
        """
        text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)

        # Forward the attention mask (when the tokenizer produced one) so
        # generate() does not have to infer it from the input ids.
        generated_ids = self.model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.get("attention_mask"),
            max_new_tokens=max_new_tokens,
        )

        # Strip the prompt tokens from the front of each output sequence so
        # only the newly generated continuation is decoded.
        completions = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        return self.tokenizer.batch_decode(completions, skip_special_tokens=True)[0]

    @torch.no_grad()
    def generate(self, query):
        """Generate a text response for ``query`` using the loaded model.

        Args:
            query: The user prompt text.

        Returns:
            The model's response as returned by the model-specific path.
        """
        if "meta-llama" in self.model_id:
            messages = [
                {"role": "user", "content": [
                    {"type": "text", "text": f"{query}"}
                ]}
            ]
            return self._chat_generate(messages)

        if "InternVL" in self.model_id:
            generation_config = dict(max_new_tokens=1024, do_sample=True)
            # The second positional argument is passed as None here.
            return self.model.chat(self.tokenizer, None, query, generation_config, history=None, return_history=False)

        messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": query}]
        return self._chat_generate(messages)

    @classmethod
    def _parse_options(cls, res):
        """Extract a list of option names from a raw model response.

        Looks for a fenced ``python`` code block first and falls back to the
        first bracketed span in the text.

        Args:
            res: Raw text returned by the model.

        Returns:
            A list of strings with ``_`` and ``-`` removed from each entry.

        Raises:
            ValueError: If no list can be located, the located text is not a
                valid Python literal, or it is not a list/tuple of strings.
        """
        fenced = cls._FENCED_LIST_RE.search(res)
        if fenced:
            literal = fenced.group(1).strip()
        else:
            # Try to extract content directly from brackets []
            bracketed = cls._BRACKET_LIST_RE.search(res)
            if not bracketed:
                raise ValueError(f"Failed to parse response: {res}")
            literal = bracketed.group(0)  # Include brackets so it parses as a list

        # SECURITY: the response is untrusted model output. It was previously
        # passed to eval(); ast.literal_eval only accepts Python literals and
        # cannot execute code.
        try:
            options = ast.literal_eval(literal)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as err:
            raise ValueError(f"Failed to parse response: {res}") from err

        if not isinstance(options, (list, tuple)) or not all(isinstance(opt, str) for opt in options):
            raise ValueError(f"Failed to parse response: {res}")

        return [opt.translate(cls._STRIP_SEPARATORS) for opt in options]

    @tenacity.retry(stop=tenacity.stop_after_delay(10))
    def answer(self, query, objects):
        """Ask the model which of ``objects`` suit ``query`` and parse its reply.

        The whole call (generation and parsing) is retried by tenacity on any
        exception until 10 seconds have elapsed.

        Args:
            query: The user's intent or purpose description.
            objects: Iterable of candidate option names (strings).

        Returns:
            A list of selected option names with ``_`` and ``-`` removed.

        Raises:
            ValueError: From a single attempt when the response cannot be
                parsed; the retry decorator decides whether it propagates.
        """
        prompt = f"""
        Extract the object that satisfies the intent of the query or determine the tool that aligns with the purpose of {query}.
        pick the best option from the following: {', '.join(objects)},
        Please return a list of all suitable options as long as they make sense in the format of a Python list in the following format: ```python\n['option1', 'option2', ...]```"""
        return self._parse_options(self.generate(prompt))