C2MV commited on
Commit
4aa57d2
·
verified ·
1 Parent(s): ef6aa08

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +6 -10
models.py CHANGED
@@ -4,14 +4,6 @@ import torch
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from sentence_transformers import SentenceTransformer
6
  from config import EMBEDDING_MODEL_NAME
7
- from pydantic import BaseModel
8
-
9
- # Clase para los modelos (opcional, si deseas utilizar pydantic)
10
- class Models(BaseModel):
11
- embedding_model: SentenceTransformer
12
- tokenizer: AutoTokenizer
13
- yi_coder_model: AutoModelForCausalLM
14
- device: torch.device
15
 
16
  # Cargar el modelo de embeddings
17
  def load_embedding_model():
@@ -22,7 +14,11 @@ def load_embedding_model():
22
  # Cargar el modelo Yi-Coder
23
  def load_yi_coder_model():
24
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
- model_path = "01-ai/Yi-Coder-9B-Chat" # Asegúrate de que esta ruta sea correcta
26
  tokenizer = AutoTokenizer.from_pretrained(model_path)
27
- yi_coder_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).to(device).eval()
 
 
 
 
28
  return tokenizer, yi_coder_model, device
 
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from sentence_transformers import SentenceTransformer
6
  from config import EMBEDDING_MODEL_NAME
 
 
 
 
 
 
 
 
7
 
8
  # Cargar el modelo de embeddings
9
  def load_embedding_model():
 
14
  # Cargar el modelo Yi-Coder
15
  def load_yi_coder_model():
16
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+ model_path = "01-ai/Yi-Coder-9B-Chat" # Asegúrate de que esta ruta sea correcta y que el modelo esté disponible
18
  tokenizer = AutoTokenizer.from_pretrained(model_path)
19
+ yi_coder_model = AutoModelForCausalLM.from_pretrained(
20
+ model_path,
21
+ torch_dtype=torch.float16,
22
+ low_cpu_mem_usage=True # Opcional: ayuda a reducir el uso de memoria al cargar el modelo
23
+ ).to(device).eval()
24
  return tokenizer, yi_coder_model, device