C2MV committed
Commit 0788e60
1 Parent(s): 9fa15f3

Update interface.py

Files changed (1)
  1. interface.py +11 -2
interface.py CHANGED
@@ -22,20 +22,27 @@ tokenizer = AutoTokenizer.from_pretrained(model_path)
 model = AutoModelForCausalLM.from_pretrained(model_path)
 # We do not move the model to the device here
 
+from decorators import spaces
+
 @spaces.GPU(duration=100)
-def generate_analysis(prompt, max_length=MAX_LENGTH, device=None):
+def generate_analysis(prompt, max_length=1024, device=None):
     try:
         if device is None:
             device = torch.device('cpu')
+
+        # Move the model to the appropriate device (GPU or CPU)
         if next(model.parameters()).device != device:
             model.to(device)
+
+        # Prepare the input data on the correct device
         input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
         max_gen_length = min(max_length + input_ids.size(1), model.config.max_position_embeddings)
 
+        # Generate the text
         generated_ids = model.generate(
             input_ids=input_ids,
             max_length=max_gen_length,
-            temperature=TEMPERATURE,
+            temperature=0.7,
             num_return_sequences=1,
             no_repeat_ngram_size=2,
             early_stopping=True
@@ -44,6 +51,8 @@ def generate_analysis(prompt, max_length=MAX_LENGTH, device=None):
         output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         analysis = output_text[len(prompt):].strip()
         return analysis
+    except RuntimeError as e:
+        return f"Error during execution: {str(e)}"
     except Exception as e:
         return f"An error occurred during the analysis: {e}"
 
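A note on the new import: from decorators import spaces points at a local module that is not part of this diff. On Hugging Face ZeroGPU Spaces, the GPU decorator normally comes from the official spaces package, so one plausible reading is that decorators.py re-exports it with a fallback for machines where that package is unavailable. A minimal sketch under that assumption; the module's real contents are not shown in this commit:

# decorators.py -- hypothetical shim; the actual module is not in this diff.
# Re-export the real `spaces` package when it is installed; otherwise fall
# back to a no-op decorator so generate_analysis still runs on plain CPU.
try:
    import spaces  # official package providing @spaces.GPU(duration=...)
except ImportError:
    from types import SimpleNamespace

    def _gpu(duration=60):
        def _wrap(fn):
            return fn  # no GPU scheduling available; run the function as-is
        return _wrap

    spaces = SimpleNamespace(GPU=_gpu)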
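For context, this is roughly how the updated function would be exposed from interface.py as a Gradio app. The wiring below is an illustration, not part of the commit; the labels, title, and launch call are assumptions:

# Hypothetical wiring -- not part of this commit. generate_analysis takes a
# prompt and returns the generated analysis string, so it maps directly onto
# a one-textbox-in, one-textbox-out Gradio interface.
import gradio as gr

demo = gr.Interface(
    fn=generate_analysis,                 # the function changed in this commit
    inputs=gr.Textbox(label="Prompt"),
    outputs=gr.Textbox(label="Analysis"),
    title="Analysis generator",
)

if __name__ == "__main__":
    demo.launch()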