C2MV commited on
Commit
c582a83
1 Parent(s): bc829c9

Create models.py

Browse files
Files changed (1) hide show
  1. models.py +28 -0
models.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models.py
2
+
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from sentence_transformers import SentenceTransformer
6
+ from config import EMBEDDING_MODEL_NAME
7
+ from pydantic import BaseModel
8
+
9
+ # Clase para los modelos (opcional, si deseas utilizar pydantic)
10
+ class Models(BaseModel):
11
+ embedding_model: SentenceTransformer
12
+ tokenizer: AutoTokenizer
13
+ yi_coder_model: AutoModelForCausalLM
14
+ device: torch.device
15
+
16
+ # Cargar el modelo de embeddings
17
+ def load_embedding_model():
18
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+ embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device=device)
20
+ return embedding_model
21
+
22
+ # Cargar el modelo Yi-Coder
23
+ def load_yi_coder_model():
24
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
+ model_path = "01-ai/Yi-Coder-9B-Chat" # Asegúrate de que esta ruta sea correcta
26
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
27
+ yi_coder_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).to(device).eval()
28
+ return tokenizer, yi_coder_model, device