Update modeling_GOT.py
Browse files — modeling_GOT.py (+2 −2)
modeling_GOT.py
CHANGED
@@ -484,7 +484,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
484 |
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
485 |
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
486 |
|
487 |
-
def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
|
488 |
|
489 |
self.disable_torch_init()
|
490 |
|
@@ -505,7 +505,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
505 |
if ocr_type == 'format':
|
506 |
qs = 'OCR with format: '
|
507 |
elif ocr_type == 'VQA':
|
508 |
-
qs = 'Answer the following Question :'
|
509 |
else:
|
510 |
qs = 'OCR: '
|
511 |
|
|
|
484 |
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
485 |
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
486 |
|
487 |
+
def chat(self, tokenizer, image_file, ocr_type, vqa_question='' ,ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
|
488 |
|
489 |
self.disable_torch_init()
|
490 |
|
|
|
505 |
if ocr_type == 'format':
|
506 |
qs = 'OCR with format: '
|
507 |
elif ocr_type == 'VQA':
|
508 |
+
qs = 'Answer the following Question :' + str(vqa_question)
|
509 |
else:
|
510 |
qs = 'OCR: '
|
511 |
|