| """ | |
| HunyuanOCR Model Wrapper | |
| Provides an easy-to-use interface for text detection and recognition | |
| """ | |
| import re | |
| import os | |
| import torch | |
| from typing import Dict, List, Tuple, Optional | |
| from PIL import Image | |
| from transformers import AutoProcessor, HunYuanVLForConditionalGeneration | |
| from transformers.modeling_outputs import CausalLMOutputWithPast | |
| import requests | |
| from io import BytesIO | |


# Monkey-patch HunYuanVLForConditionalGeneration.generate to fix a dtype issue
def patched_generate(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    imgs: Optional[list[torch.FloatTensor]] = None,
    imgs_pos: Optional[list[int]] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[list[int]] = None,
    **kwargs,
) -> CausalLMOutputWithPast:
    if "inputs_embeds" in kwargs:
        raise NotImplementedError("`inputs_embeds` is not supported")
    inputs_embeds = self.model.embed_tokens(input_ids)
    if self.vit is not None and pixel_values is not None:
        # PATCH: use the model's dtype instead of forcing bfloat16
        pixel_values = pixel_values.to(self.dtype)
        image_embeds = self.vit(pixel_values, image_grid_thw)
        # The ViT may sit on a different GPU than the LLM because of accelerate's auto device mapping
        image_embeds = image_embeds.to(input_ids.device, non_blocking=True)
        image_mask, _ = self.get_placeholder_mask(
            input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
        )
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
    return super(HunYuanVLForConditionalGeneration, self).generate(
        inputs=input_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        **kwargs,
    )


HunYuanVLForConditionalGeneration.generate = patched_generate


class HunyuanOCR:
    """Wrapper class for the HunyuanOCR model for text spotting tasks"""

    def __init__(self, model_path: str = "tencent/HunyuanOCR", device: Optional[str] = None):
        """
        Initialize the HunyuanOCR model

        Args:
            model_path: Path or name of the model (default: "tencent/HunyuanOCR")
            device: Device to load the model on (cuda/mps/cpu). Auto-detected if None.
        """
        # Check if a local model exists when using the default path
        if model_path == "tencent/HunyuanOCR" and os.path.exists("HunyuanOCR"):
            print("Found local HunyuanOCR model, using it instead of downloading...")
            model_path = "HunyuanOCR"
        self.model_path = model_path

        # Auto-detect device if not specified
        if device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        else:
            self.device = device

        print(f"Loading HunyuanOCR model on {self.device}...")

        # Load processor
        self.processor = AutoProcessor.from_pretrained(model_path, use_fast=False)

        # Determine dtype based on device
        if self.device == "cuda":
            torch_dtype = torch.bfloat16
        elif self.device == "mps":
            torch_dtype = torch.float16
        else:
            torch_dtype = torch.float32

        # Load model
        self.model = HunYuanVLForConditionalGeneration.from_pretrained(
            model_path,
            attn_implementation="eager",
            torch_dtype=torch_dtype,
            device_map="auto" if self.device == "cuda" else None
        )
        if self.device != "cuda":
            self.model = self.model.to(self.device)
        print("Model loaded successfully!")

    def clean_repeated_substrings(self, text: str) -> str:
        """
        Clean repeated substrings in text output

        Args:
            text: Input text to clean
        Returns:
            Cleaned text
        """
        n = len(text)
        if n < 8000:
            return text
        for length in range(2, n // 10 + 1):
            candidate = text[-length:]
            count = 0
            i = n - length
            while i >= 0 and text[i:i + length] == candidate:
                count += 1
                i -= length
            if count >= 10:
                return text[:n - length * (count - 1)]
        return text
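
    # Illustration of the method above (hypothetical input, not from the original file):
    # for a string of 8000 filler characters followed by "abc" repeated 12 times,
    # clean_repeated_substrings() drops 11 of the 12 trailing "abc" copies and keeps one;
    # any input shorter than 8000 characters is returned unchanged.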

    def load_image(self, image_source: str) -> Image.Image:
        """
        Load image from a URL or file path

        Args:
            image_source: URL or file path to the image
        Returns:
            PIL Image object
        """
        if image_source.startswith(('http://', 'https://')):
            response = requests.get(image_source)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
        else:
            return Image.open(image_source)

    def detect_text(self, image: Image.Image, prompt: Optional[str] = None) -> str:
        """
        Detect and recognize text in an image with bounding boxes

        Args:
            image: PIL Image object
            prompt: Custom prompt (default: text spotting prompt in Chinese)
        Returns:
            Model response with detected text and coordinates
        """
        # Default prompt for text spotting:
        # "Detect and recognize the text in the image, and output the text content and coordinates in a formatted way."
        if prompt is None:
            prompt = "检测并识别图片中的文字,将文本内容与坐标格式化输出。"

        # Prepare messages
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Apply chat template
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Process inputs
        inputs = self.processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt",
        )

        # Generate
        with torch.no_grad():
            # Get the model's dtype
            model_dtype = next(self.model.parameters()).dtype
            if self.device == "cuda":
                device = next(self.model.parameters()).device
                inputs = inputs.to(device)
            else:
                # Move to device and cast floating-point tensors to the model's dtype
                new_inputs = {}
                for k, v in inputs.items():
                    if torch.is_tensor(v):
                        v = v.to(self.device)
                        if v.dtype in [torch.float16, torch.bfloat16, torch.float32]:
                            v = v.to(dtype=model_dtype)
                        new_inputs[k] = v
                    else:
                        new_inputs[k] = v
                inputs = new_inputs
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=2048,
                do_sample=False
            )

        # Decode output
        if "input_ids" in inputs:
            input_ids = inputs["input_ids"]
        else:
            input_ids = inputs["inputs"]
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

        # Clean repeated substrings
        output_text = self.clean_repeated_substrings(output_text)
        return output_text

    def parse_detection_results(self, response: str, image_width: int, image_height: int) -> List[Dict]:
        """
        Parse the detection response into a structured format with denormalized coordinates

        Args:
            response: Model output text
            image_width: Image width in pixels
            image_height: Image height in pixels
        Returns:
            List of dictionaries with 'text', 'x1', 'y1', 'x2', 'y2' keys
        """
        results = []
        # Pattern to match text and coordinates: text(x1,y1),(x2,y2)
        pattern = r'([^()]+?)(\(\d+,\d+\),\(\d+,\d+\))'
        matches = re.finditer(pattern, response)
        for match in matches:
            try:
                text = match.group(1).strip()
                coords = match.group(2)
                # Parse coordinates
                coord_pattern = r'\((\d+),(\d+)\)'
                coord_matches = re.findall(coord_pattern, coords)
                if len(coord_matches) == 2:
                    # Coordinates are normalized to [0, 1000]; denormalize them
                    x1_norm, y1_norm = float(coord_matches[0][0]), float(coord_matches[0][1])
                    x2_norm, y2_norm = float(coord_matches[1][0]), float(coord_matches[1][1])
                    # Denormalize to image dimensions
                    x1 = int(x1_norm * image_width / 1000)
                    y1 = int(y1_norm * image_height / 1000)
                    x2 = int(x2_norm * image_width / 1000)
                    y2 = int(y2_norm * image_height / 1000)
                    results.append({
                        'text': text,
                        'x1': x1,
                        'y1': y1,
                        'x2': x2,
                        'y2': y2
                    })
            except Exception as e:
                print(f"Error parsing detection result: {str(e)}")
                continue
        return results
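
    # Worked example of the denormalization above (hypothetical model output, not from
    # the original file): on a 2000x1000 px image, the response fragment
    # "Hello(100,200),(300,400)" would be parsed into
    # {'text': 'Hello', 'x1': 200, 'y1': 200, 'x2': 600, 'y2': 400},
    # since each normalized coordinate is scaled by image_size / 1000.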

    def process_image(self, image_source: str, prompt: Optional[str] = None) -> Tuple[str, List[Dict], Image.Image]:
        """
        Complete pipeline: load image, detect text, parse results

        Args:
            image_source: Path or URL to image
            prompt: Custom prompt for detection
        Returns:
            Tuple of (raw_response, parsed_results, image)
        """
        # Load image
        image = self.load_image(image_source)
        image_width, image_height = image.size
        # Detect text
        response = self.detect_text(image, prompt)
        # Parse results
        parsed_results = self.parse_detection_results(response, image_width, image_height)
        return response, parsed_results, image
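

# Minimal usage sketch (not part of the original module). "sample.jpg" is a
# placeholder path; any local file or http(s) URL works. Running this loads the
# full model, so treat it as an illustration of the intended call pattern only.
if __name__ == "__main__":
    ocr = HunyuanOCR()
    raw_response, detections, img = ocr.process_image("sample.jpg")
    print(raw_response)
    for det in detections:
        print(f"{det['text']}: ({det['x1']}, {det['y1']}) -> ({det['x2']}, {det['y2']})")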