Jeney committed on
Commit
1c7db93
1 Parent(s): f8a58cc

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +32 -17
handler.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  import io
 
3
 
4
  from typing import Any, Dict
5
  from PIL import Image
@@ -18,20 +19,34 @@ class EndpointHandler:
18
  inputs = data.pop("inputs", data)
19
  image = inputs["image"]
20
  image = Image.open(io.BytesIO(eval(image)))
21
- text = inputs["text"]
22
-
23
- # preprocess
24
- encoding = self.processor(image, return_tensors="pt")
25
- outputs = self.model(**encoding)
26
- # postprocess the prediction
27
- logits = outputs.logits
28
- best_idx = logits.argmax(-1).item()
29
- best_answer = self.model.config.id2label[best_idx]
30
- probabilities = torch.softmax(logits, dim=-1)[0]
31
- id2label = self.model.config.id2label
32
- answers = []
33
- for idx, prob in enumerate(probabilities):
34
- answer = id2label[idx]
35
- answer_score = float(prob)
36
- answers.append({"answer": answer, "answer_score": answer_score})
37
- return {"best_answer": best_answer, "answers": answers}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import io
3
+ import re
4
 
5
  from typing import Any, Dict
6
  from PIL import Image
 
19
  inputs = data.pop("inputs", data)
20
  image = inputs["image"]
21
  image = Image.open(io.BytesIO(eval(image)))
22
+ return self.process_document(image)
23
+
24
+
25
+ def process_document(self, image):
26
+ # prepare encoder inputs
27
+ pixel_values = self.processor(image, return_tensors="pt").pixel_values
28
+
29
+ # prepare decoder inputs
30
+ task_prompt = "<s_cord-v2>"
31
+ decoder_input_ids = self.processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
32
+
33
+ # generate answer
34
+ outputs = self.model.generate(
35
+ pixel_values.to(self.device),
36
+ decoder_input_ids=decoder_input_ids.to(self.device),
37
+ max_length=self.model.decoder.config.max_position_embeddings,
38
+ early_stopping=True,
39
+ pad_token_id=self.processor.tokenizer.pad_token_id,
40
+ eos_token_id=self.processor.tokenizer.eos_token_id,
41
+ use_cache=True,
42
+ num_beams=1,
43
+ bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
44
+ return_dict_in_generate=True,
45
+ )
46
+
47
+ # postprocess
48
+ sequence = self.processor.batch_decode(outputs.sequences)[0]
49
+ sequence = sequence.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
50
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
51
+
52
+ return self.processor.token2json(sequence)