NotYuSheng committed
Commit 510f333
1 Parent(s): 18f83a2

Create multimodal_ai.py

Files changed (1)
  1. multimodal_ai.py +47 -0
multimodal_ai.py ADDED
@@ -0,0 +1,47 @@
+ import os
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from huggingface_hub import HfApi, login
+
+ class MultimodalAI:
+     def __init__(self):
+         # Read the Hugging Face token from the environment (e.g. loaded from a .env file)
+         self.HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
+
+         # Fail early if the token is missing
+         if self.HUGGINGFACE_TOKEN is None:
+             raise ValueError("HUGGINGFACE_TOKEN environment variable is not set.")
+
+         # Authenticate with Hugging Face
+         self.api = HfApi()
+         login(token=self.HUGGINGFACE_TOKEN)
+
+         # Model selection
+         self.model_name = "openai-community/gpt2"
+
+         # Use a CUDA-enabled GPU if one is available; otherwise fall back to the CPU
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # Load the model and tokenizer
+         self._load_model_and_tokenizer()
+
+     def _load_model_and_tokenizer(self):
+         # Load the GPT-2 model and tokenizer
+         self.model = AutoModelForCausalLM.from_pretrained(self.model_name,
+                                                           token=self.HUGGINGFACE_TOKEN).to(self.device)
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name,
+                                                        token=self.HUGGINGFACE_TOKEN)
+         # GPT-2 has no pad token; reuse the end-of-sequence token for padding
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+
+     def generate_response(self, text_input, max_new_tokens=50):
+         # Tokenize the input text and move the tensors to the model's device
+         inputs = self.tokenizer(text_input, return_tensors="pt").to(self.device)
+
+         # Generate a continuation without tracking gradients
+         with torch.no_grad():
+             outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens, pad_token_id=self.tokenizer.pad_token_id)
+
+         # Decode and return the response
+         response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return response_text
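
For reference, a minimal usage sketch of the class added in this commit; the prompt string and token count below are illustrative, and it assumes HUGGINGFACE_TOKEN is exported in the environment:

# Illustrative usage of the MultimodalAI class (prompt and max_new_tokens are example values)
ai = MultimodalAI()
print(ai.generate_response("The quick brown fox", max_new_tokens=30))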