omkar56 committed
Commit f486b06 · 1 Parent(s): f105e35

Update main.py

Files changed (1)
  1. main.py +263 -43
main.py CHANGED
@@ -1,16 +1,118 @@
import os
os.system("sudo apt-get install xclip")
import nltk
from fastapi import FastAPI, File, Request, UploadFile, Body, Depends, HTTPException
from fastapi.security.api_key import APIKeyHeader
from typing import Optional, Annotated
from fastapi.encoders import jsonable_encoder
from PIL import Image
from io import BytesIO
- import pytesseract
- from nltk.tokenize import sent_tokenize
- from transformers import MarianMTModel, MarianTokenizer
- nltk.download('punkt')

API_KEY = os.environ.get("API_KEY")

@@ -22,10 +124,52 @@ def get_api_key(api_key: Optional[str] = Depends(api_key_header)):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return api_key

# Image path
img_dir = "./data"
# Get tesseract language list
choices = os.popen('tesseract --list-langs').read().split('\n')[1:-1]
# Convert tesseract language list to pytesseract language
def ocr_lang(lang_list):
    lang_str = ""
@@ -38,6 +182,8 @@ def ocr_lang(lang_list):

    lang_str = "".join(lang_list[:-1])
    return lang_str
# ocr tesseract
def ocr_tesseract(img, languages):
    print("[img]", img)
@@ -45,52 +191,126 @@ def ocr_tesseract(img, languages):
    ocr_str = pytesseract.image_to_string(img, lang=ocr_lang(languages))
    return ocr_str

- @app.post("/api/ocr", response_model=dict)
- async def ocr(
-     api_key: str = Depends(get_api_key),
-     image: UploadFile = File(...),
-     # languages: list = Body(["eng"])
- ):
-
    try:
-         content = await image.read()
-         image = Image.open(BytesIO(content))
-         print("[image]",image)
-         if hasattr(pytesseract, "image_to_string"):
-             print("Image to string function is available")
-             # print(pytesseract.image_to_string(image, lang = 'eng'))
-             text = ocr_tesseract(image, ['eng'])
-         else:
-             print("Image to string function is not available")
-             # text = pytesseract.image_to_string(image, lang="+".join(languages))
    except Exception as e:
-         return {"error": str(e)}, 500

-     return {"ImageText": "text"}

- @app.post("/api/translate", response_model=dict)
- async def translate(
-     api_key: str = Depends(get_api_key),
-     text: str = Body(...),
-     src: str = "en",
-     trg: str = "zh",
- ):
-     if api_key != API_KEY:
-         return {"error": "Invalid API key"}, 401

-     tokenizer, model = get_model(src, trg)

-     translated_text = ""
-     for sentence in sent_tokenize(text):
-         translated_sub = model.generate(**tokenizer(sentence, return_tensors="pt"))[0]
-         translated_text += tokenizer.decode(translated_sub, skip_special_tokens=True) + "\n"

-     return jsonable_encoder({"translated_text": translated_text})

- def get_model(src: str, trg: str):
-     model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
-     tokenizer = MarianTokenizer.from_pretrained(model_name)
-     model = MarianMTModel.from_pretrained(model_name)
-     return tokenizer, model

+ # import os
+ # os.system("sudo apt-get install xclip")
+ # import nltk
+ # from fastapi import FastAPI, File, Request, UploadFile, Body, Depends, HTTPException
+ # from fastapi.security.api_key import APIKeyHeader
+ # from typing import Optional, Annotated
+ # from fastapi.encoders import jsonable_encoder
+ # from PIL import Image
+ # from io import BytesIO
+ # import pytesseract
+ # from nltk.tokenize import sent_tokenize
+ # from transformers import MarianMTModel, MarianTokenizer
+ # nltk.download('punkt')
+
+ # API_KEY = os.environ.get("API_KEY")
+
+ # app = FastAPI()
+ # api_key_header = APIKeyHeader(name="api_key", auto_error=False)
+
+ # def get_api_key(api_key: Optional[str] = Depends(api_key_header)):
+ #     if api_key is None or api_key != API_KEY:
+ #         raise HTTPException(status_code=401, detail="Unauthorized access")
+ #     return api_key
+
+ # # Image path
+ # img_dir = "./data"
+ # # Get tesseract language list
+ # choices = os.popen('tesseract --list-langs').read().split('\n')[1:-1]
+ # # Convert tesseract language list to pytesseract language
+ # def ocr_lang(lang_list):
+ #     lang_str = ""
+ #     lang_len = len(lang_list)
+ #     if lang_len == 1:
+ #         return lang_list[0]
+ #     else:
+ #         for i in range(lang_len):
+ #             lang_list.insert(lang_len - i, "+")
+
+ #     lang_str = "".join(lang_list[:-1])
+ #     return lang_str
+ # # ocr tesseract
+ # def ocr_tesseract(img, languages):
+ #     print("[img]", img)
+ #     print("[languages]", languages)
+ #     ocr_str = pytesseract.image_to_string(img, lang=ocr_lang(languages))
+ #     return ocr_str
+
+ # @app.post("/api/ocr", response_model=dict)
+ # async def ocr(
+ #     api_key: str = Depends(get_api_key),
+ #     image: UploadFile = File(...),
+ #     # languages: list = Body(["eng"])
+ # ):
+
+ #     try:
+ #         content = await image.read()
+ #         image = Image.open(BytesIO(content))
+ #         print("[image]",image)
+ #         if hasattr(pytesseract, "image_to_string"):
+ #             print("Image to string function is available")
+ #             # print(pytesseract.image_to_string(image, lang = 'eng'))
+ #             text = ocr_tesseract(image, ['eng'])
+ #         else:
+ #             print("Image to string function is not available")
+ #             # text = pytesseract.image_to_string(image, lang="+".join(languages))
+ #     except Exception as e:
+ #         return {"error": str(e)}, 500
+
+ #     return {"ImageText": "text"}
+
+ # @app.post("/api/translate", response_model=dict)
+ # async def translate(
+ #     api_key: str = Depends(get_api_key),
+ #     text: str = Body(...),
+ #     src: str = "en",
+ #     trg: str = "zh",
+ # ):
+ #     if api_key != API_KEY:
+ #         return {"error": "Invalid API key"}, 401
+
+ #     tokenizer, model = get_model(src, trg)
+
+ #     translated_text = ""
+ #     for sentence in sent_tokenize(text):
+ #         translated_sub = model.generate(**tokenizer(sentence, return_tensors="pt"))[0]
+ #         translated_text += tokenizer.decode(translated_sub, skip_special_tokens=True) + "\n"
+
+ #     return jsonable_encoder({"translated_text": translated_text})
+
+ # def get_model(src: str, trg: str):
+ #     model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
+ #     tokenizer = MarianTokenizer.from_pretrained(model_name)
+ #     model = MarianMTModel.from_pretrained(model_name)
+ #     return tokenizer, model
+
+ # OCR Translate v0.2
+
+
import os
+
os.system("sudo apt-get install xclip")
+
+ # import gradio as gr
import nltk
+ import pyclip
+ import pytesseract
+ from nltk.tokenize import sent_tokenize
+ from transformers import MarianMTModel, MarianTokenizer
+ # Added below code
from fastapi import FastAPI, File, Request, UploadFile, Body, Depends, HTTPException
from fastapi.security.api_key import APIKeyHeader
from typing import Optional, Annotated
from fastapi.encoders import jsonable_encoder
from PIL import Image
from io import BytesIO

API_KEY = os.environ.get("API_KEY")

        raise HTTPException(status_code=401, detail="Unauthorized access")
    return api_key

+ @app.post("/api/ocr", response_model=dict)
+ async def ocr(
+     api_key: str = Depends(get_api_key),
+     image: UploadFile = File(...),
+     # languages: list = Body(["eng"])
+ ):
+     try:
+         content = await image.read()
+         image = Image.open(BytesIO(content))
+         print("[image]",image)
+         if hasattr(pytesseract, "image_to_string"):
+             print("Image to string function is available")
+             # print(pytesseract.image_to_string(image, lang = 'eng'))
+             text = ocr_tesseract(image, ['eng'])
+         else:
+             print("Image to string function is not available")
+             # text = pytesseract.image_to_string(image, lang="+".join(languages))
+     except Exception as e:
+         return {"error": str(e)}, 500
+
+     return {"ImageText": "text"}
+
+ nltk.download('punkt')
+
+ OCR_TR_DESCRIPTION = '''# OCR Translate v0.2
+ <div id="content_align">OCR translation system based on Tesseract</div>'''
+
# Image path
img_dir = "./data"
+
# Get tesseract language list
choices = os.popen('tesseract --list-langs').read().split('\n')[1:-1]
+
+
+ # Translation model selection
+ def model_choice(src="en", trg="zh"):
+     # https://huggingface.co/Helsinki-NLP/opus-mt-zh-en
+     # https://huggingface.co/Helsinki-NLP/opus-mt-en-zh
+     model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}" # Model name
+
+     tokenizer = MarianTokenizer.from_pretrained(model_name) # tokenizer
+     model = MarianMTModel.from_pretrained(model_name) # Model
+
+     return tokenizer, model
+
+
# Convert tesseract language list to pytesseract language
def ocr_lang(lang_list):
    lang_str = ""

    lang_str = "".join(lang_list[:-1])
    return lang_str
+
+
# ocr tesseract
def ocr_tesseract(img, languages):
    print("[img]", img)

    ocr_str = pytesseract.image_to_string(img, lang=ocr_lang(languages))
    return ocr_str

+
+ # Clear
+ def clear_content():
+     return None
+
+
+ # copy to clipboard
+ def cp_text(input_text):
+     # sudo apt-get install xclip
    try:
+         pyclip.copy(input_text)
    except Exception as e:
+         print("sudo apt-get install xclip")
+         print(e)


+ # clear clipboard
+ def cp_clear():
+     pyclip.clear()


+ # translate
+ def translate(input_text, inputs_transStyle):
+     # reference:https://huggingface.co/docs/transformers/model_doc/marian
+     if input_text is None or input_text == "":
+         return "System prompt: There is no content to translate!"

+     # Select translation model
+     trans_src, trans_trg = inputs_transStyle.split("-")[0], inputs_transStyle.split("-")[1]
+     tokenizer, model = model_choice(trans_src, trans_trg)
+
+     translate_text = ""
+     input_text_list = input_text.split("\n\n")
+
+     translate_text_list_tmp = []
+     for i in range(len(input_text_list)):
+         if input_text_list[i] != "":
+             translate_text_list_tmp.append(input_text_list[i])
+
+     for i in range(len(translate_text_list_tmp)):
+         translated_sub = model.generate(
+             **tokenizer(sent_tokenize(translate_text_list_tmp[i]), return_tensors="pt", truncation=True, padding=True))
+         tgt_text_sub = [tokenizer.decode(t, skip_special_tokens=True) for t in translated_sub]
+         translate_text_sub = "".join(tgt_text_sub)
+         translate_text = translate_text + "\n\n" + translate_text_sub
+
+     return translate_text[2:]
+
+
+ # def main():
+
+ #     with gr.Blocks(css='style.css') as ocr_tr:
+ #         gr.Markdown(OCR_TR_DESCRIPTION)
+
+ #         # -------------- OCR text extraction --------------
+ #         with gr.Box():
+
+ #             with gr.Row():
+ #                 gr.Markdown("### Step 01: Text Extraction")
+
+ #             with gr.Row():
+ #                 with gr.Column():
+ #                     with gr.Row():
+ #                         inputs_img = gr.Image(image_mode="RGB", source="upload", type="pil", label="image")
+ #                     with gr.Row():
+ #                         inputs_lang = gr.CheckboxGroup(choices=["chi_sim", "eng"],
+ #                                                        type="value",
+ #                                                        value=['eng'],
+ #                                                        label='language')
+
+ #                     with gr.Row():
+ #                         clear_img_btn = gr.Button('Clear')
+ #                         ocr_btn = gr.Button(value='OCR Extraction', variant="primary")
+
+ #                 with gr.Column():
+ #                     with gr.Row():
+ #                         outputs_text = gr.Textbox(label="Extract content", lines=20)
+ #                     with gr.Row():
+ #                         inputs_transStyle = gr.Radio(choices=["zh-en", "en-zh"],
+ #                                                      type="value",
+ #                                                      value="zh-en",
+ #                                                      label='translation mode')
+ #                     with gr.Row():
+ #                         clear_text_btn = gr.Button('Clear')
+ #                         translate_btn = gr.Button(value='Translate', variant="primary")
+
+ #             with gr.Row():
+ #                 example_list = [["./data/test.png", ["eng"]], ["./data/test02.png", ["eng"]],
+ #                                 ["./data/test03.png", ["chi_sim"]]]
+ #                 gr.Examples(example_list, [inputs_img, inputs_lang], outputs_text, ocr_tesseract, cache_examples=False)
+
+ #         # -------------- translate --------------
+ #         with gr.Box():
+
+ #             with gr.Row():
+ #                 gr.Markdown("### Step 02: Translation")
+
+ #             with gr.Row():
+ #                 outputs_tr_text = gr.Textbox(label="Translate Content", lines=20)
+
+ #             with gr.Row():
+ #                 cp_clear_btn = gr.Button(value='Clear Clipboard')
+ #                 cp_btn = gr.Button(value='Copy to clipboard', variant="primary")
+
+ #         # ---------------------- OCR Tesseract ----------------------
+ #         ocr_btn.click(fn=ocr_tesseract, inputs=[inputs_img, inputs_lang], outputs=[
+ #             outputs_text,])
+ #         clear_img_btn.click(fn=clear_content, inputs=[], outputs=[inputs_img])
+
+ #         # ---------------------- translate ----------------------
+ #         translate_btn.click(fn=translate, inputs=[outputs_text, inputs_transStyle], outputs=[outputs_tr_text])
+ #         clear_text_btn.click(fn=clear_content, inputs=[], outputs=[outputs_text])
+
+ #         # ---------------------- copy to clipboard ----------------------
+ #         cp_btn.click(fn=cp_text, inputs=[outputs_tr_text], outputs=[])
+ #         cp_clear_btn.click(fn=cp_clear, inputs=[], outputs=[])
+
+ #     ocr_tr.launch(inbrowser=True)


+ # if __name__ == '__main__':
+ #     main()
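
For reference, a minimal client sketch for the /api/ocr endpoint introduced in this commit. The header name "api_key" (from APIKeyHeader(name="api_key")) and the multipart field name "image" (UploadFile = File(...)) come from the diff above; the base URL, port, the use of the requests library, and the sample image path are illustrative assumptions, not part of the commit.

# Hypothetical client for the new /api/ocr endpoint (sketch, not part of the commit).
import os
import requests

BASE_URL = "http://localhost:7860"  # assumed host/port; adjust to where the app is served

def ocr_request(image_path: str) -> dict:
    # The server compares the "api_key" header against its API_KEY environment variable.
    headers = {"api_key": os.environ["API_KEY"]}
    with open(image_path, "rb") as f:
        # The upload must use the multipart field name "image".
        resp = requests.post(f"{BASE_URL}/api/ocr", headers=headers, files={"image": f})
    resp.raise_for_status()
    return resp.json()  # a JSON object with an "ImageText" field on success

if __name__ == "__main__":
    print(ocr_request("./data/test.png"))  # sample path taken from the repo's example list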