Jeney committed on
Commit
f8a58cc
1 Parent(s): aaebaaf

handler.py BASIC

Browse files
Files changed (1) hide show
  1. handler.py +37 -0
handler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import io
3
+
4
+ from typing import Any, Dict
5
+ from PIL import Image
6
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
7
+
8
+
9
class EndpointHandler:
    """Request handler that runs an image through a Donut processor and model.

    The processor and model are loaded once at construction time and reused
    for every call.
    """

    def __init__(self, path=""):
        """Load the processor and model and move the model to the device.

        Args:
            path: Accepted for interface compatibility. It is not used: the
                checkpoint name below is hard-coded.
        """
        # NOTE(review): `path` is ignored in favour of a fixed hub checkpoint;
        # confirm this is intentional before switching to `path`.
        self.processor = DonutProcessor.from_pretrained("debu-das/donut_receipt_v2.29")
        self.model = VisionEncoderDecoderModel.from_pretrained("debu-das/donut_receipt_v2.29")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Put the weights on the selected device so a GPU is actually used
        # when one is available.
        self.model.to(self.device)

    @staticmethod
    def _decode_image_payload(payload) -> bytes:
        """Turn the request's ``image`` field into raw image bytes.

        Args:
            payload: Either a bytes-like object, or a string holding a Python
                bytes literal (for example the result of ``str(image_bytes)``).

        Returns:
            The image content as ``bytes``.

        Raises:
            ValueError, SyntaxError: If a string payload is not a valid
                Python literal.
            TypeError: If the literal does not evaluate to ``bytes``.
        """
        import ast  # local import: only this helper needs it

        if isinstance(payload, (bytes, bytearray, memoryview)):
            return bytes(payload)

        # SECURITY: the payload comes straight from the request body. It used
        # to be passed to eval(), which executes arbitrary code supplied by
        # the caller. ast.literal_eval only accepts literals, so a bytes
        # literal still decodes while any other expression is rejected.
        decoded = ast.literal_eval(payload)
        if not isinstance(decoded, bytes):
            raise TypeError(
                "image payload must be a bytes literal, got "
                + type(decoded).__name__
            )
        return decoded

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run inference on one request.

        Args:
            data: Request payload. If it has an ``"inputs"`` key, that value
                is used as the inputs mapping (and removed from ``data``);
                otherwise ``data`` itself is used. The inputs mapping must
                contain ``"image"``. A ``"text"`` entry may be present but is
                not consumed by this handler.

        Returns:
            A dict with ``"best_answer"`` (the label of the arg-max logit)
            and ``"answers"`` (a list of ``{"answer", "answer_score"}`` dicts,
            one per entry of the softmaxed logits).
        """
        inputs = data.pop("inputs", data)
        image_bytes = self._decode_image_payload(inputs["image"])
        image = Image.open(io.BytesIO(image_bytes))

        # Preprocess, then place the tensors on the same device as the model.
        encoding = self.processor(image, return_tensors="pt").to(self.device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            # NOTE(review): only the processor output is passed to the model,
            # and the post-processing below reads `config.id2label` and calls
            # `.item()` on the arg-max, which needs a single-element result.
            # Confirm this matches the loaded checkpoint; an encoder-decoder
            # model may instead need decoder inputs or `generate()`.
            outputs = self.model(**encoding)

        logits = outputs.logits
        id2label = self.model.config.id2label
        best_answer = id2label[logits.argmax(-1).item()]

        probabilities = torch.softmax(logits, dim=-1)[0]
        answers = [
            {"answer": id2label[idx], "answer_score": float(prob)}
            for idx, prob in enumerate(probabilities)
        ]
        return {"best_answer": best_answer, "answers": answers}