Vaibhav Srivastav commited on
Commit
f5de0f6
·
1 Parent(s): 973e88f
Files changed (1) hide show
  1. handler.py +2 -5
handler.py CHANGED
@@ -5,10 +5,8 @@ import torch
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
  # load model and processor from path
8
- self.processor = AutoProcessor.from_pretrained(path)
9
- self.model = MusicgenForConditionalGeneration.from_pretrained(path)
10
- # self.model = AutoModelForSeq2SeqLM.from_pretrained(path, device_map="auto", load_in_8bit=True)
11
- # self.tokenizer = AutoTokenizer.from_pretrained(path)
12
 
13
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
14
  """
@@ -21,7 +19,6 @@ class EndpointHandler:
21
  parameters = data.pop("parameters", None)
22
 
23
  # preprocess
24
- # input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
25
  inputs = processor(
26
  text=inputs,
27
  padding=True,
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
  # load model and processor from path
8
+ self.processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
9
+ self.model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
 
 
10
 
11
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
12
  """
 
19
  parameters = data.pop("parameters", None)
20
 
21
  # preprocess
 
22
  inputs = processor(
23
  text=inputs,
24
  padding=True,