Svngoku commited on
Commit
f2a7627
·
verified ·
1 Parent(s): 23df37a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -147
app.py CHANGED
@@ -6,19 +6,19 @@ import shutil
6
  import time
7
  import pymupdf as fitz
8
  import logging
9
- import mimetypes
10
- from mistralai import Mistral
11
  from mistralai.models import OCRResponse
12
  from typing import Union, List, Tuple, Optional, Dict
13
  from tenacity import retry, stop_after_attempt, wait_exponential
 
14
  import tempfile
15
 
16
  # Constants
17
- SUPPORTED_IMAGE_TYPES = [".jpg", ".png", ".jpeg", ".avif"]
18
- SUPPORTED_DOCUMENT_TYPES = [".pdf"]
19
  UPLOAD_FOLDER = "./uploads"
20
  MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
21
- MAX_PDF_PAGES = 50 # Not used anymore, kept for reference
22
 
23
  # Configuration
24
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
@@ -34,7 +34,6 @@ class OCRProcessor:
34
  if not api_key or not isinstance(api_key, str):
35
  raise ValueError("Valid API key must be provided")
36
  self.client = Mistral(api_key=api_key)
37
- self.file_ids_to_delete = []
38
  self._validate_client()
39
 
40
  def _validate_client(self) -> None:
@@ -46,52 +45,93 @@ class OCRProcessor:
46
  raise ValueError(f"API key validation failed: {str(e)}")
47
 
48
  @staticmethod
49
- def _check_file_size(file_path: str) -> None:
50
- if not os.path.exists(file_path):
51
- raise FileNotFoundError(f"File not found: {file_path}")
52
- size = os.path.getsize(file_path)
 
 
 
 
53
  if size > MAX_FILE_SIZE:
54
  raise ValueError(f"File size exceeds {MAX_FILE_SIZE/1024/1024}MB limit")
55
 
56
- def _upload_file_for_ocr(self, file_path: str) -> str:
57
- filename = os.path.basename(file_path)
 
 
58
  try:
59
- with open(file_path, "rb") as f:
60
- uploaded_file = self.client.files.upload(
61
- file={"file_name": filename, "content": f},
62
- purpose="ocr"
63
- )
64
- self.file_ids_to_delete.append(uploaded_file.id)
65
- signed_url = self.client.files.get_signed_url(uploaded_file.id)
66
- return signed_url.url
 
 
 
 
 
 
 
 
67
  except Exception as e:
68
- logger.error(f"Failed to upload file {filename}: {str(e)}")
69
- raise ValueError(f"Failed to upload file: {str(e)}")
70
 
71
  @staticmethod
72
- def _convert_first_page(pdf_path: str) -> Optional[str]:
 
 
 
 
 
 
 
 
 
73
  try:
74
  pdf_document = fitz.open(pdf_path)
75
- if pdf_document.page_count == 0:
76
  pdf_document.close()
77
- return None
78
- page = pdf_document[0]
79
- pix = page.get_pixmap(dpi=100)
80
- img_path = os.path.join(UPLOAD_FOLDER, f"preview_{int(time.time())}.png")
81
- pix.save(img_path)
 
 
82
  pdf_document.close()
83
- return img_path
84
  except Exception as e:
85
- logger.error(f"Error converting first page of {pdf_path}: {str(e)}")
86
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
89
- def _call_ocr_api(self, document: Dict) -> OCRResponse:
 
90
  try:
91
  logger.info("Calling OCR API")
92
  response = self.client.ocr.process(
93
  model="mistral-ocr-latest",
94
- document=document,
95
  include_image_base64=True
96
  )
97
  return response
@@ -99,61 +139,64 @@ class OCRProcessor:
99
  logger.error(f"OCR API call failed: {str(e)}")
100
  raise
101
 
102
- def process_file(self, file: gr.File) -> Tuple[str, str]:
103
  """Process uploaded file (image or PDF)."""
104
  if not file:
105
- return "## No file provided", ""
106
- file_path = file.name
107
- self._check_file_size(file_path)
108
- file_name = os.path.basename(file_path)
109
- ext = os.path.splitext(file_name)[1].lower()
110
- try:
111
- if ext in SUPPORTED_IMAGE_TYPES:
112
- mime_type, _ = mimetypes.guess_type(file_path)
113
- if mime_type is None:
114
- mime_type = "image/png"
115
- with open(file_path, "rb") as image_file:
116
- image_data = image_file.read()
117
- base64_encoded = base64.b64encode(image_data).decode('utf-8')
118
- data_url = f"data:{mime_type};base64,{base64_encoded}"
119
- document = {"type": "image_url", "image_url": data_url}
120
- response = self._call_ocr_api(document)
121
- markdown = self._combine_markdown(response)
122
- return markdown, file_path
123
- elif ext in SUPPORTED_DOCUMENT_TYPES:
124
- signed_url = self._upload_file_for_ocr(file_path)
125
- document = {"type": "document_url", "document_url": signed_url}
126
- response = self._call_ocr_api(document)
127
  markdown = self._combine_markdown(response)
128
- return markdown, file_path
129
- else:
130
- return f"## Unsupported file type. Supported: {', '.join(SUPPORTED_IMAGE_TYPES + SUPPORTED_DOCUMENT_TYPES)}", file_path
131
- except Exception as e:
132
- logger.error(f"Error processing file {file_name}: {str(e)}")
133
- return f"## Error processing file: {str(e)}", file_path
134
 
135
- def process_url(self, url: str) -> Tuple[str, str]:
136
  """Process URL (image or PDF)."""
137
  if not url:
138
- return "## No URL provided", ""
139
- parsed_url = url.split('/')[-1] if '/' in url else url
140
- ext = os.path.splitext(parsed_url)[1].lower()
141
- try:
142
- if ext in SUPPORTED_IMAGE_TYPES:
143
- document = {"type": "image_url", "image_url": url}
144
- response = self._call_ocr_api(document)
145
- markdown = self._combine_markdown(response)
146
- return markdown, url
147
- elif ext in SUPPORTED_DOCUMENT_TYPES:
148
- document = {"type": "document_url", "document_url": url}
149
- response = self._call_ocr_api(document)
 
 
 
 
 
 
 
 
150
  markdown = self._combine_markdown(response)
151
- return markdown, url
152
- else:
153
- return f"## Unsupported URL type. Supported: {', '.join(SUPPORTED_IMAGE_TYPES + SUPPORTED_DOCUMENT_TYPES)}", url
154
- except Exception as e:
155
- logger.error(f"Error processing URL {url}: {str(e)}")
156
- return f"## Error processing URL: {str(e)}", url
157
 
158
  @staticmethod
159
  def _combine_markdown(response: OCRResponse) -> str:
@@ -173,53 +216,20 @@ class OCRProcessor:
173
  markdown_parts.append(markdown)
174
  return "\n\n".join(markdown_parts) or "## No text detected"
175
 
176
- def update_file_preview(file):
177
- if not file:
178
- return gr.update(value=[])
179
- ext = os.path.splitext(os.path.basename(file.name))[1].lower()
180
- if ext in SUPPORTED_IMAGE_TYPES:
181
- return gr.update(value=[file.name])
182
- elif ext in SUPPORTED_DOCUMENT_TYPES:
183
- first_page = OCRProcessor._convert_first_page(file.name)
184
- return gr.update(value=[first_page] if first_page else [])
185
- else:
186
- return gr.update(value=[])
187
-
188
- def update_url_preview(url):
189
- if not url:
190
- return gr.update(value=[])
191
- parsed_url = url.split('/')[-1] if '/' in url else url
192
- ext = os.path.splitext(parsed_url)[1].lower()
193
- if ext in SUPPORTED_IMAGE_TYPES:
194
- return gr.update(value=[url])
195
- elif ext == '.pdf': # Only preview PDFs
196
- try:
197
- response = requests.get(url, timeout=10, stream=True)
198
- response.raise_for_status()
199
- with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_pdf:
200
- shutil.copyfileobj(response.raw, temp_pdf)
201
- temp_pdf_path = temp_pdf.name
202
- first_page = OCRProcessor._convert_first_page(temp_pdf_path)
203
- os.unlink(temp_pdf_path)
204
- return gr.update(value=[first_page] if first_page else [])
205
- except Exception as e:
206
- logger.error(f"URL preview error: {str(e)}")
207
- return gr.update(value=[])
208
- else:
209
- return gr.update(value=[])
210
-
211
  def create_interface():
212
  css = """
213
  .output-markdown {font-size: 14px; max-height: 500px; overflow-y: auto;}
214
  .status {color: #666; font-style: italic;}
215
  .preview {max-height: 300px;}
216
  """
 
217
  with gr.Blocks(title="Mistral OCR Demo", css=css) as demo:
218
  gr.Markdown("# Mistral OCR Demo")
219
- gr.Markdown(f"""Process PDFs and images (max {MAX_FILE_SIZE/1024/1024}MB) via upload or URL.
220
- Supported: Images ({', '.join(SUPPORTED_IMAGE_TYPES)}), Documents ({', '.join(SUPPORTED_DOCUMENT_TYPES)}).
221
- View previews and OCR results with embedded images.
222
- Learn more at [Mistral OCR](https://docs.mistral.ai/capabilities/document_ai/basic_ocr).""")
 
223
 
224
  # API Key Setup
225
  with gr.Row():
@@ -234,61 +244,66 @@ def create_interface():
234
  return processor, "✅ API key validated"
235
  except Exception as e:
236
  return None, f"❌ Error: {str(e)}"
237
-
238
  set_key_btn.click(fn=init_processor, inputs=api_key_input, outputs=[processor_state, status])
239
 
240
  # File Upload Tab
241
  with gr.Tab("Upload File"):
242
  with gr.Row():
243
- file_input = gr.File(label="Upload Image/PDF", file_types=SUPPORTED_IMAGE_TYPES + SUPPORTED_DOCUMENT_TYPES)
244
- file_preview = gr.Gallery(label="Preview", elem_classes="preview")
245
  file_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown")
246
- file_raw_output = gr.Textbox(label="Source Path")
247
  file_button = gr.Button("Process", variant="primary")
248
 
249
- file_input.change(fn=update_file_preview, inputs=file_input, outputs=file_preview)
250
-
251
- def process_file_fn(p, f):
252
- if not p:
253
- return "## Set API key first", ""
254
- return p.process_file(f)
255
 
 
256
  file_button.click(
257
- fn=process_file_fn,
258
  inputs=[processor_state, file_input],
259
- outputs=[file_output, file_raw_output]
260
  )
261
 
262
  # URL Tab
263
  with gr.Tab("URL Input"):
264
  with gr.Row():
265
- url_input = gr.Textbox(label="URL to Image/PDF")
266
- url_preview = gr.Gallery(label="Preview", elem_classes="preview")
267
  url_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown")
268
- url_raw_output = gr.Textbox(label="Source URL")
269
  url_button = gr.Button("Process", variant="primary")
270
 
271
- url_input.change(fn=update_url_preview, inputs=url_input, outputs=url_preview)
272
-
273
- def process_url_fn(p, u):
274
- if not p:
275
- return "## Set API key first", ""
276
- return p.process_url(u)
 
 
 
 
 
 
277
 
 
278
  url_button.click(
279
- fn=process_url_fn,
280
  inputs=[processor_state, url_input],
281
- outputs=[url_output, url_raw_output]
282
  )
283
 
 
284
  gr.Examples(
285
  examples=[],
286
  inputs=[file_input, url_input]
287
  )
 
288
  return demo
289
 
290
  if __name__ == "__main__":
291
  os.environ['START_TIME'] = time.strftime('%Y-%m-%d %H:%M:%S')
292
- print(f"===== Application Startup at {os.environ['START_TIME']} ===")
293
- demo = create_interface()
294
- demo.launch(share=True, max_threads=1)
 
6
  import time
7
  import pymupdf as fitz
8
  import logging
9
+ from mistralai import Mistral, ImageURLChunk
 
10
  from mistralai.models import OCRResponse
11
  from typing import Union, List, Tuple, Optional, Dict
12
  from tenacity import retry, stop_after_attempt, wait_exponential
13
+ from concurrent.futures import ThreadPoolExecutor
14
  import tempfile
15
 
16
  # Constants
17
+ SUPPORTED_IMAGE_TYPES = [".jpg", ".png", ".jpeg"]
18
+ SUPPORTED_PDF_TYPES = [".pdf"]
19
  UPLOAD_FOLDER = "./uploads"
20
  MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
21
+ MAX_PDF_PAGES = 50
22
 
23
  # Configuration
24
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
 
34
  if not api_key or not isinstance(api_key, str):
35
  raise ValueError("Valid API key must be provided")
36
  self.client = Mistral(api_key=api_key)
 
37
  self._validate_client()
38
 
39
  def _validate_client(self) -> None:
 
45
  raise ValueError(f"API key validation failed: {str(e)}")
46
 
47
  @staticmethod
48
+ def _check_file_size(file_input: Union[str, bytes]) -> None:
49
+ if isinstance(file_input, str) and os.path.exists(file_input):
50
+ size = os.path.getsize(file_input)
51
+ elif hasattr(file_input, 'read'):
52
+ size = len(file_input.read())
53
+ file_input.seek(0)
54
+ else:
55
+ size = len(file_input)
56
  if size > MAX_FILE_SIZE:
57
  raise ValueError(f"File size exceeds {MAX_FILE_SIZE/1024/1024}MB limit")
58
 
59
+ @staticmethod
60
+ def _save_uploaded_file(file_input: Union[str, bytes], filename: str) -> str:
61
+ clean_filename = os.path.basename(filename).replace(os.sep, "_")
62
+ file_path = os.path.join(UPLOAD_FOLDER, f"{int(time.time())}_{clean_filename}")
63
  try:
64
+ if isinstance(file_input, str) and file_input.startswith("http"):
65
+ response = requests.get(file_input, timeout=30)
66
+ response.raise_for_status()
67
+ with open(file_path, 'wb') as f:
68
+ f.write(response.content)
69
+ elif isinstance(file_input, str) and os.path.exists(file_input):
70
+ shutil.copy2(file_input, file_path)
71
+ else:
72
+ with open(file_path, 'wb') as f:
73
+ if hasattr(file_input, 'read'):
74
+ shutil.copyfileobj(file_input, f)
75
+ else:
76
+ f.write(file_input)
77
+ if not os.path.exists(file_path):
78
+ raise FileNotFoundError(f"Failed to save file at {file_path}")
79
+ return file_path
80
  except Exception as e:
81
+ logger.error(f"Error saving file {filename}: {str(e)}")
82
+ raise
83
 
84
  @staticmethod
85
+ def _encode_image(image_path: str) -> str:
86
+ try:
87
+ with open(image_path, "rb") as image_file:
88
+ return base64.b64encode(image_file.read()).decode('utf-8')
89
+ except Exception as e:
90
+ logger.error(f"Error encoding image {image_path}: {str(e)}")
91
+ raise ValueError(f"Failed to encode image: {str(e)}")
92
+
93
+ @staticmethod
94
+ def _pdf_to_images(pdf_path: str) -> List[Tuple[str, str]]:
95
  try:
96
  pdf_document = fitz.open(pdf_path)
97
+ if pdf_document.page_count > MAX_PDF_PAGES:
98
  pdf_document.close()
99
+ raise ValueError(f"PDF exceeds maximum page limit of {MAX_PDF_PAGES}")
100
+
101
+ with ThreadPoolExecutor() as executor:
102
+ image_data = list(executor.map(
103
+ lambda i: OCRProcessor._convert_page(pdf_path, i),
104
+ range(pdf_document.page_count)
105
+ ))
106
  pdf_document.close()
107
+ return [data for data in image_data if data]
108
  except Exception as e:
109
+ logger.error(f"Error converting PDF to images: {str(e)}")
110
+ return []
111
+
112
+ @staticmethod
113
+ def _convert_page(pdf_path: str, page_num: int) -> Tuple[str, str]:
114
+ try:
115
+ pdf_document = fitz.open(pdf_path)
116
+ page = pdf_document[page_num]
117
+ pix = page.get_pixmap(dpi=150)
118
+ image_path = os.path.join(UPLOAD_FOLDER, f"page_{page_num + 1}_{int(time.time())}.png")
119
+ pix.save(image_path)
120
+ encoded = OCRProcessor._encode_image(image_path)
121
+ pdf_document.close()
122
+ return image_path, encoded
123
+ except Exception as e:
124
+ logger.error(f"Error converting page {page_num}: {str(e)}")
125
+ return None, None
126
 
127
  @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
128
+ def _call_ocr_api(self, encoded_image: str) -> OCRResponse:
129
+ base64_url = f"data:image/png;base64,{encoded_image}"
130
  try:
131
  logger.info("Calling OCR API")
132
  response = self.client.ocr.process(
133
  model="mistral-ocr-latest",
134
+ document=ImageURLChunk(image_url=base64_url),
135
  include_image_base64=True
136
  )
137
  return response
 
139
  logger.error(f"OCR API call failed: {str(e)}")
140
  raise
141
 
142
+ def process_file(self, file: gr.File) -> Tuple[str, str, List[str]]:
143
  """Process uploaded file (image or PDF)."""
144
  if not file:
145
+ return "## No file provided", "", []
146
+
147
+ file_name = file.name
148
+ self._check_file_size(file)
149
+ file_path = self._save_uploaded_file(file, file_name)
150
+
151
+ if file_name.lower().endswith(tuple(SUPPORTED_IMAGE_TYPES)):
152
+ encoded_image = self._encode_image(file_path)
153
+ response = self._call_ocr_api(encoded_image)
154
+ markdown = self._combine_markdown(response)
155
+ return markdown, file_path, [file_path]
156
+
157
+ elif file_name.lower().endswith('.pdf'):
158
+ image_data = self._pdf_to_images(file_path)
159
+ if not image_data:
160
+ return "## No pages converted from PDF", file_path, []
161
+
162
+ ocr_results = []
163
+ image_paths = [path for path, _ in image_data]
164
+ for _, encoded in image_data:
165
+ response = self._call_ocr_api(encoded)
 
166
  markdown = self._combine_markdown(response)
167
+ ocr_results.append(markdown)
168
+ return "\n\n".join(ocr_results), file_path, image_paths
169
+
170
+ return "## Unsupported file type", file_path, []
 
 
171
 
172
+ def process_url(self, url: str) -> Tuple[str, str, List[str]]:
173
  """Process URL (image or PDF)."""
174
  if not url:
175
+ return "## No URL provided", "", []
176
+
177
+ file_name = url.split('/')[-1] or f"file_{int(time.time())}"
178
+ file_path = self._save_uploaded_file(url, file_name)
179
+
180
+ if file_name.lower().endswith(tuple(SUPPORTED_IMAGE_TYPES)):
181
+ encoded_image = self._encode_image(file_path)
182
+ response = self._call_ocr_api(encoded_image)
183
+ markdown = self._combine_markdown(response)
184
+ return markdown, url, [file_path]
185
+
186
+ elif file_name.lower().endswith('.pdf'):
187
+ image_data = self._pdf_to_images(file_path)
188
+ if not image_data:
189
+ return "## No pages converted from PDF", url, []
190
+
191
+ ocr_results = []
192
+ image_paths = [path for path, _ in image_data]
193
+ for _, encoded in image_data:
194
+ response = self._call_ocr_api(encoded)
195
  markdown = self._combine_markdown(response)
196
+ ocr_results.append(markdown)
197
+ return "\n\n".join(ocr_results), url, image_paths
198
+
199
+ return "## Unsupported URL content type", url, []
 
 
200
 
201
  @staticmethod
202
  def _combine_markdown(response: OCRResponse) -> str:
 
216
  markdown_parts.append(markdown)
217
  return "\n\n".join(markdown_parts) or "## No text detected"
218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  def create_interface():
220
  css = """
221
  .output-markdown {font-size: 14px; max-height: 500px; overflow-y: auto;}
222
  .status {color: #666; font-style: italic;}
223
  .preview {max-height: 300px;}
224
  """
225
+
226
  with gr.Blocks(title="Mistral OCR Demo", css=css) as demo:
227
  gr.Markdown("# Mistral OCR Demo")
228
+ gr.Markdown(f"""
229
+ Process PDFs and images (max {MAX_FILE_SIZE/1024/1024}MB, {MAX_PDF_PAGES} pages for PDFs) via upload or URL.
230
+ View previews and OCR results with embedded images.
231
+ Learn more at [Mistral OCR](https://mistral.ai/news/mistral-ocr).
232
+ """)
233
 
234
  # API Key Setup
235
  with gr.Row():
 
244
  return processor, "✅ API key validated"
245
  except Exception as e:
246
  return None, f"❌ Error: {str(e)}"
247
+
248
  set_key_btn.click(fn=init_processor, inputs=api_key_input, outputs=[processor_state, status])
249
 
250
  # File Upload Tab
251
  with gr.Tab("Upload File"):
252
  with gr.Row():
253
+ file_input = gr.File(label="Upload PDF/Image", file_types=SUPPORTED_IMAGE_TYPES + SUPPORTED_PDF_TYPES)
254
+ file_preview = gr.Gallery(label="Preview", elem_classes="preview")
255
  file_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown")
256
+ file_raw_output = gr.Textbox(label="Raw File Path")
257
  file_button = gr.Button("Process", variant="primary")
258
 
259
+ def update_file_preview(file):
260
+ return [file.name] if file else []
 
 
 
 
261
 
262
+ file_input.change(fn=update_file_preview, inputs=file_input, outputs=file_preview)
263
  file_button.click(
264
+ fn=lambda p, f: p.process_file(f) if p else ("## Set API key first", "", []),
265
  inputs=[processor_state, file_input],
266
+ outputs=[file_output, file_raw_output, file_preview]
267
  )
268
 
269
  # URL Tab
270
  with gr.Tab("URL Input"):
271
  with gr.Row():
272
+ url_input = gr.Textbox(label="URL to PDF/Image")
273
+ url_preview = gr.Gallery(label="Preview", elem_classes="preview")
274
  url_output = gr.Markdown(label="OCR Result", elem_classes="output-markdown")
275
+ url_raw_output = gr.Textbox(label="Raw URL")
276
  url_button = gr.Button("Process", variant="primary")
277
 
278
+ def update_url_preview(url):
279
+ if not url:
280
+ return []
281
+ try:
282
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.tmp')
283
+ response = requests.get(url, timeout=10)
284
+ temp_file.write(response.content)
285
+ temp_file.close()
286
+ return [temp_file.name]
287
+ except Exception as e:
288
+ logger.error(f"URL preview error: {str(e)}")
289
+ return []
290
 
291
+ url_input.change(fn=update_url_preview, inputs=url_input, outputs=url_preview)
292
  url_button.click(
293
+ fn=lambda p, u: p.process_url(u) if p else ("## Set API key first", "", []),
294
  inputs=[processor_state, url_input],
295
+ outputs=[url_output, url_raw_output, url_preview]
296
  )
297
 
298
+ # Examples
299
  gr.Examples(
300
  examples=[],
301
  inputs=[file_input, url_input]
302
  )
303
+
304
  return demo
305
 
306
  if __name__ == "__main__":
307
  os.environ['START_TIME'] = time.strftime('%Y-%m-%d %H:%M:%S')
308
+ print(f"===== Application Startup at {os.environ['START_TIME']} =====")
309
+ create_interface().launch(share=True, max_threads=1)