abhisheksan commited on
Commit
51ed73b
1 Parent(s): 34aad78

Refactor model loading to use a consistent model name variable in PoetryGenerationService

Browse files
Files changed (1) hide show
  1. app/services/poetry_generation.py +5 -5
app/services/poetry_generation.py CHANGED
@@ -23,10 +23,10 @@ class ModelManager:
23
  def __init__(self):
24
 
25
  # Initialize tokenizer and model
26
- self.tokenizer = AutoTokenizer.from_pretrained(self._model_name)
27
  self.tokenizer.pad_token = self.tokenizer.eos_token
28
  self.model = AutoModelForCausalLM.from_pretrained(
29
- self._model_name,
30
  torch_dtype=torch.float16,
31
  device_map="auto"
32
  )
@@ -54,7 +54,7 @@ def get_hf_token() -> str:
54
  "Please set your Hugging Face access token."
55
  )
56
  return token
57
-
58
  class PoetryGenerationService:
59
  def __init__(self):
60
  # Get model manager instance
@@ -66,11 +66,11 @@ class PoetryGenerationService:
66
  """Preload the models during application startup"""
67
  try:
68
  # Initialize tokenizer and model
69
- self.tokenizer = AutoTokenizer.from_pretrained(self._model_name)
70
  self.tokenizer.pad_token = self.tokenizer.eos_token
71
 
72
  self.model = AutoModelForCausalLM.from_pretrained(
73
- self._model_name,
74
  torch_dtype=torch.float16,
75
  device_map="auto"
76
  )
 
23
  def __init__(self):
24
 
25
  # Initialize tokenizer and model
26
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
27
  self.tokenizer.pad_token = self.tokenizer.eos_token
28
  self.model = AutoModelForCausalLM.from_pretrained(
29
+ model_name,
30
  torch_dtype=torch.float16,
31
  device_map="auto"
32
  )
 
54
  "Please set your Hugging Face access token."
55
  )
56
  return token
57
+ model_name = "meta-llama/Llama-3.2-1B-Instruct"
58
  class PoetryGenerationService:
59
  def __init__(self):
60
  # Get model manager instance
 
66
  """Preload the models during application startup"""
67
  try:
68
  # Initialize tokenizer and model
69
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
70
  self.tokenizer.pad_token = self.tokenizer.eos_token
71
 
72
  self.model = AutoModelForCausalLM.from_pretrained(
73
+ model_name,
74
  torch_dtype=torch.float16,
75
  device_map="auto"
76
  )