Update handler.py
Browse files - handler.py +4 -4
handler.py
CHANGED
@@ -10,8 +10,8 @@ from transformers import DonutProcessor, VisionEncoderDecoderModel
|
|
10 |
class EndpointHandler:
|
11 |
def __init__(self, path=""):
|
12 |
# load model and processor from path
|
13 |
-
self.processor = DonutProcessor.from_pretrained("
|
14 |
-
self.model = VisionEncoderDecoderModel.from_pretrained("
|
15 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
|
17 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
@@ -22,7 +22,7 @@ class EndpointHandler:
|
|
22 |
return self.process_document(image)
|
23 |
|
24 |
|
25 |
-
def process_document(self, image):
|
26 |
# prepare encoder inputs
|
27 |
pixel_values = self.processor(image, return_tensors="pt").pixel_values
|
28 |
|
@@ -49,4 +49,4 @@ class EndpointHandler:
|
|
49 |
sequence = sequence.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
|
50 |
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
51 |
|
52 |
-
return self.processor.token2json(sequence)
|
|
|
10 |
class EndpointHandler:
|
11 |
def __init__(self, path=""):
|
12 |
# load model and processor from path
|
13 |
+
self.processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
|
14 |
+
self.model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
|
15 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
|
17 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
22 |
return self.process_document(image)
|
23 |
|
24 |
|
25 |
+
def process_document(self, image:Image) -> dict[str, Any]:
|
26 |
# prepare encoder inputs
|
27 |
pixel_values = self.processor(image, return_tensors="pt").pixel_values
|
28 |
|
|
|
49 |
sequence = sequence.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
|
50 |
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
51 |
|
52 |
+
return self.processor.token2json(sequence)
|