dduy193 committed on
Commit
00038d3
•
1 Parent(s): 02dd597

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -56,8 +56,6 @@ model2.config.id2label = mapping
56
  model3 = models.resnet101(pretrained=False)
57
  model3.fc = torch.nn.Linear(2048, len(genres))
58
 
59
- device = torch.device('cpu')
60
-
61
 
62
  # In[5]:
63
 
@@ -94,19 +92,19 @@ class Multimodal(torch.nn.Module):
94
 
95
  model = Multimodal(model1, model2, model3)
96
  model.load_state_dict(torch.load('multimodel.pt', map_location=torch.device('cpu')))
97
- model.to(device)
98
  model.eval()
99
- model_gen.to(device)
100
- model_gen.eval()
101
 
102
 
103
  # In[8]:
104
 
105
 
106
- def generate_plot(title: str, model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer) -> str:
107
  quote = 'What is the story of the movie {}?'
 
 
108
 
109
- input_ids = tokenizer(quote.format(title), return_tensors='pt')
110
  output = model.generate(input_ids, max_length=256, do_sample=True, temperature=0.09)
111
  return tokenizer.decode(output[0], skip_special_tokens=True)
112
 
@@ -117,24 +115,25 @@ def generate_plot(title: str, model: AutoModelForSeq2SeqLM, tokenizer: AutoToken
117
  def inference(title, image,
118
  tokenizer1=tokenizer1, tokenizer2=tokenizer2, tokenizer_gen=tokenizer_gen,
119
  model_gen=model_gen, model=model,
120
- genres=genres):
121
  title_input = tokenizer1(title, return_tensors='pt', padding=True, truncation=True)
122
- title_input_ids = title_input['input_ids']
123
- title_attention_mask = title_input['attention_mask']
124
 
125
- plot = generate_plot(title, model_gen, tokenizer_gen)
126
  plot_input = tokenizer2(plot, return_tensors='pt', padding=True, truncation=True)
127
- plot_input_ids = plot_input['input_ids']
128
- plot_attention_mask = plot_input['attention_mask']
129
 
130
  # If image is not uploaded
131
  if image is None:
132
- image_input = torch.zeros((1, 3, 224, 224))
133
 
134
  else:
135
  image_input = image.resize((224, 224))
136
  image_input = v2.ToTensor()(image_input)
137
  image_input = image_input.unsqueeze(0)
 
138
 
139
  output = model(title_input_ids, title_attention_mask, plot_input_ids, plot_attention_mask, image_input)
140
  output = torch.sigmoid(output)
@@ -163,3 +162,4 @@ app = gr.Interface(fn=inference, inputs=["text", "pil"], outputs="text", title="
163
 
164
 
165
  app.launch(share=True)
 
 
56
  model3 = models.resnet101(pretrained=False)
57
  model3.fc = torch.nn.Linear(2048, len(genres))
58
 
 
 
59
 
60
  # In[5]:
61
 
 
92
 
93
  model = Multimodal(model1, model2, model3)
94
  model.load_state_dict(torch.load('multimodel.pt', map_location=torch.device('cpu')))
 
95
  model.eval()
96
+ device = torch.device('cpu')
 
97
 
98
 
99
  # In[8]:
100
 
101
 
102
+ def generate_plot(title: str, model: AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer, device) -> str:
103
  quote = 'What is the story of the movie {}?'
104
+ model_gen.to(device)
105
+ model_gen.eval()
106
 
107
+ input_ids = tokenizer(quote.format(title), return_tensors='pt').input_ids.to(device)
108
  output = model.generate(input_ids, max_length=256, do_sample=True, temperature=0.09)
109
  return tokenizer.decode(output[0], skip_special_tokens=True)
110
 
 
115
  def inference(title, image,
116
  tokenizer1=tokenizer1, tokenizer2=tokenizer2, tokenizer_gen=tokenizer_gen,
117
  model_gen=model_gen, model=model,
118
+ genres=genres, device=device):
119
  title_input = tokenizer1(title, return_tensors='pt', padding=True, truncation=True)
120
+ title_input_ids = title_input['input_ids'].to(device)
121
+ title_attention_mask = title_input['attention_mask'].to(device)
122
 
123
+ plot = generate_plot(title, model_gen, tokenizer_gen, device)
124
  plot_input = tokenizer2(plot, return_tensors='pt', padding=True, truncation=True)
125
+ plot_input_ids = plot_input['input_ids'].to(device)
126
+ plot_attention_mask = plot_input['attention_mask'].to(device)
127
 
128
  # If image is not uploaded
129
  if image is None:
130
+ image_input = torch.zeros((1, 3, 224, 224)).to(device)
131
 
132
  else:
133
  image_input = image.resize((224, 224))
134
  image_input = v2.ToTensor()(image_input)
135
  image_input = image_input.unsqueeze(0)
136
+ image_input = image_input.to(device)
137
 
138
  output = model(title_input_ids, title_attention_mask, plot_input_ids, plot_attention_mask, image_input)
139
  output = torch.sigmoid(output)
 
162
 
163
 
164
  app.launch(share=True)
165
+