Update main.py
Browse files
main.py
CHANGED
|
@@ -28,7 +28,6 @@ IMAGE_GEN_API_URL = "https://www.chatwithmono.xyz/api/image"
|
|
| 28 |
MODERATION_API_URL = "https://www.chatwithmono.xyz/api/moderation"
|
| 29 |
|
| 30 |
# --- Model Definitions ---
|
| 31 |
-
# Added florence-2-ocr for the new endpoint
|
| 32 |
AVAILABLE_MODELS = [
|
| 33 |
{"id": "gpt-4-turbo", "object": "model", "created": int(time.time()), "owned_by": "system"},
|
| 34 |
{"id": "gpt-4o", "object": "model", "created": int(time.time()), "owned_by": "system"},
|
|
@@ -43,10 +42,9 @@ MODEL_ALIASES = {}
|
|
| 43 |
app = FastAPI(
|
| 44 |
title="OpenAI Compatible API",
|
| 45 |
description="An adapter for various services to be compatible with the OpenAI API specification.",
|
| 46 |
-
version="1.1.0"
|
| 47 |
)
|
| 48 |
|
| 49 |
-
# Initialize Gradio client for OCR globally to avoid re-initialization on each request
|
| 50 |
try:
|
| 51 |
ocr_client = Client("multimodalart/Florence-2-l4")
|
| 52 |
except Exception as e:
|
|
@@ -54,8 +52,7 @@ except Exception as e:
|
|
| 54 |
ocr_client = None
|
| 55 |
|
| 56 |
# --- Pydantic Models ---
|
| 57 |
-
|
| 58 |
-
# /v1/chat/completions
|
| 59 |
class Message(BaseModel):
|
| 60 |
role: str
|
| 61 |
content: str
|
|
@@ -66,7 +63,6 @@ class ChatRequest(BaseModel):
|
|
| 66 |
stream: Optional[bool] = False
|
| 67 |
tools: Optional[Any] = None
|
| 68 |
|
| 69 |
-
# /v1/images/generations
|
| 70 |
class ImageGenerationRequest(BaseModel):
|
| 71 |
prompt: str
|
| 72 |
aspect_ratio: Optional[str] = "1:1"
|
|
@@ -74,12 +70,10 @@ class ImageGenerationRequest(BaseModel):
|
|
| 74 |
user: Optional[str] = None
|
| 75 |
model: Optional[str] = "default"
|
| 76 |
|
| 77 |
-
# /v1/moderations
|
| 78 |
class ModerationRequest(BaseModel):
|
| 79 |
input: Union[str, List[str]]
|
| 80 |
model: Optional[str] = "text-moderation-stable"
|
| 81 |
|
| 82 |
-
# /v1/ocr
|
| 83 |
class OcrRequest(BaseModel):
|
| 84 |
image_url: Optional[str] = Field(None, description="URL of the image to process.")
|
| 85 |
image_b64: Optional[str] = Field(None, description="Base64 encoded string of the image to process.")
|
|
@@ -88,11 +82,9 @@ class OcrRequest(BaseModel):
|
|
| 88 |
@classmethod
|
| 89 |
def check_sources(cls, data: Any) -> Any:
|
| 90 |
if isinstance(data, dict):
|
| 91 |
-
url = data.get('image_url')
|
| 92 |
-
b64 = data.get('image_b64')
|
| 93 |
-
if not (url or b64):
|
| 94 |
raise ValueError('Either image_url or image_b64 must be provided.')
|
| 95 |
-
if url and b64:
|
| 96 |
raise ValueError('Provide either image_url or image_b64, not both.')
|
| 97 |
return data
|
| 98 |
|
|
@@ -100,10 +92,8 @@ class OcrResponse(BaseModel):
|
|
| 100 |
ocr_text: str
|
| 101 |
raw_response: dict
|
| 102 |
|
| 103 |
-
|
| 104 |
-
# --- Helper Function for Random ID Generation ---
|
| 105 |
def generate_random_id(prefix: str, length: int = 29) -> str:
|
| 106 |
-
"""Generates a cryptographically secure, random alphanumeric ID."""
|
| 107 |
population = string.ascii_letters + string.digits
|
| 108 |
random_part = "".join(secrets.choice(population) for _ in range(length))
|
| 109 |
return f"{prefix}{random_part}"
|
|
@@ -115,6 +105,7 @@ async def list_models():
|
|
| 115 |
"""Lists the available models."""
|
| 116 |
return {"object": "list", "data": AVAILABLE_MODELS}
|
| 117 |
|
|
|
|
| 118 |
@app.post("/v1/chat/completions", tags=["Chat"])
|
| 119 |
async def chat_completion(request: ChatRequest):
|
| 120 |
"""Handles chat completion requests, supporting streaming and non-streaming."""
|
|
@@ -128,7 +119,6 @@ async def chat_completion(request: ChatRequest):
|
|
| 128 |
'user-agent': 'Mozilla/5.0',
|
| 129 |
}
|
| 130 |
|
| 131 |
-
# Handle tool prompting
|
| 132 |
if request.tools:
|
| 133 |
tool_prompt = f"""You have access to the following tools. To call a tool, please respond with JSON for a tool call within <tool_call></tool_call> XML tags. Respond in the format {{"name": tool name, "parameters": dictionary of argument name and its value}}. Do not use variables.
|
| 134 |
Tools: {";".join(f"<tool>{tool}</tool>" for tool in request.tools)}
|
|
@@ -181,7 +171,6 @@ Response Format for tool call:
|
|
| 181 |
|
| 182 |
in_tool_call = False
|
| 183 |
tool_call_buffer = ""
|
| 184 |
-
# Process text that might come after the tool call in the same chunk
|
| 185 |
remaining_text = current_buffer.split("</tool_call>", 1)[1]
|
| 186 |
if remaining_text:
|
| 187 |
content_piece = remaining_text
|
|
@@ -191,16 +180,14 @@ Response Format for tool call:
|
|
| 191 |
if "<tool_call>" in content_piece:
|
| 192 |
in_tool_call = True
|
| 193 |
tool_call_buffer += content_piece.split("<tool_call>", 1)[1]
|
| 194 |
-
# Process text that came before the tool call
|
| 195 |
text_before = content_piece.split("<tool_call>", 1)[0]
|
| 196 |
if text_before:
|
| 197 |
-
# Send the text before the tool call starts
|
| 198 |
delta = {"content": text_before, "tool_calls": None}
|
| 199 |
chunk = {"id": chat_id, "object": "chat.completion.chunk", "created": created, "model": model_id,
|
| 200 |
"choices": [{"index": 0, "delta": delta, "finish_reason": None}], "usage": None}
|
| 201 |
yield f"data: {json.dumps(chunk)}\n\n"
|
| 202 |
if "</tool_call>" not in tool_call_buffer:
|
| 203 |
-
continue
|
| 204 |
|
| 205 |
if not in_tool_call:
|
| 206 |
delta = {"content": content_piece}
|
|
@@ -217,7 +204,6 @@ Response Format for tool call:
|
|
| 217 |
except (json.JSONDecodeError, AttributeError): pass
|
| 218 |
break
|
| 219 |
|
| 220 |
-
# Finalize
|
| 221 |
final_usage = None
|
| 222 |
if usage_info:
|
| 223 |
final_usage = {"prompt_tokens": usage_info.get("promptTokens", 0), "completion_tokens": usage_info.get("completionTokens", 0), "total_tokens": usage_info.get("promptTokens", 0) + usage_info.get("completionTokens", 0)}
|
|
@@ -232,7 +218,7 @@ Response Format for tool call:
|
|
| 232 |
yield "data: [DONE]\n\n"
|
| 233 |
|
| 234 |
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
| 235 |
-
else:
|
| 236 |
full_response, usage_info = "", {}
|
| 237 |
try:
|
| 238 |
async with httpx.AsyncClient(timeout=120) as client:
|
|
@@ -300,6 +286,8 @@ async def generate_images(request: ImageGenerationRequest):
|
|
| 300 |
return JSONResponse(status_code=500, content={"error": "An internal error occurred.", "details": str(e)})
|
| 301 |
return {"created": int(time.time()), "data": results}
|
| 302 |
|
|
|
|
|
|
|
| 303 |
@app.post("/v1/ocr", response_model=OcrResponse, tags=["OCR"])
|
| 304 |
async def perform_ocr(request: OcrRequest):
|
| 305 |
"""
|
|
@@ -322,16 +310,40 @@ async def perform_ocr(request: OcrRequest):
|
|
| 322 |
|
| 323 |
prediction = ocr_client.predict(image=handle_file(image_path), task_prompt="OCR", api_name="/process_image")
|
| 324 |
|
| 325 |
-
if not prediction or not isinstance(prediction, tuple):
|
| 326 |
-
raise HTTPException(status_code=502, detail="Invalid response from OCR service.")
|
| 327 |
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
| 332 |
raise HTTPException(status_code=500, detail=f"An error occurred during OCR processing: {str(e)}")
|
| 333 |
finally:
|
| 334 |
-
if temp_file_path:
|
| 335 |
os.unlink(temp_file_path)
|
| 336 |
|
| 337 |
@app.post("/v1/moderations", tags=["Moderation"])
|
|
|
|
| 28 |
MODERATION_API_URL = "https://www.chatwithmono.xyz/api/moderation"
|
| 29 |
|
| 30 |
# --- Model Definitions ---
|
|
|
|
| 31 |
AVAILABLE_MODELS = [
|
| 32 |
{"id": "gpt-4-turbo", "object": "model", "created": int(time.time()), "owned_by": "system"},
|
| 33 |
{"id": "gpt-4o", "object": "model", "created": int(time.time()), "owned_by": "system"},
|
|
|
|
| 42 |
app = FastAPI(
|
| 43 |
title="OpenAI Compatible API",
|
| 44 |
description="An adapter for various services to be compatible with the OpenAI API specification.",
|
| 45 |
+
version="1.1.1" # Incremented version for the fix
|
| 46 |
)
|
| 47 |
|
|
|
|
| 48 |
try:
|
| 49 |
ocr_client = Client("multimodalart/Florence-2-l4")
|
| 50 |
except Exception as e:
|
|
|
|
| 52 |
ocr_client = None
|
| 53 |
|
| 54 |
# --- Pydantic Models ---
|
| 55 |
+
# (Pydantic models are unchanged and remain the same as before)
|
|
|
|
| 56 |
class Message(BaseModel):
|
| 57 |
role: str
|
| 58 |
content: str
|
|
|
|
| 63 |
stream: Optional[bool] = False
|
| 64 |
tools: Optional[Any] = None
|
| 65 |
|
|
|
|
| 66 |
class ImageGenerationRequest(BaseModel):
|
| 67 |
prompt: str
|
| 68 |
aspect_ratio: Optional[str] = "1:1"
|
|
|
|
| 70 |
user: Optional[str] = None
|
| 71 |
model: Optional[str] = "default"
|
| 72 |
|
|
|
|
| 73 |
class ModerationRequest(BaseModel):
|
| 74 |
input: Union[str, List[str]]
|
| 75 |
model: Optional[str] = "text-moderation-stable"
|
| 76 |
|
|
|
|
| 77 |
class OcrRequest(BaseModel):
|
| 78 |
image_url: Optional[str] = Field(None, description="URL of the image to process.")
|
| 79 |
image_b64: Optional[str] = Field(None, description="Base64 encoded string of the image to process.")
|
|
|
|
| 82 |
@classmethod
|
| 83 |
def check_sources(cls, data: Any) -> Any:
|
| 84 |
if isinstance(data, dict):
|
| 85 |
+
if not (data.get('image_url') or data.get('image_b64')):
|
|
|
|
|
|
|
| 86 |
raise ValueError('Either image_url or image_b64 must be provided.')
|
| 87 |
+
if data.get('image_url') and data.get('image_b64'):
|
| 88 |
raise ValueError('Provide either image_url or image_b64, not both.')
|
| 89 |
return data
|
| 90 |
|
|
|
|
| 92 |
ocr_text: str
|
| 93 |
raw_response: dict
|
| 94 |
|
| 95 |
+
# --- Helper Function ---
|
|
|
|
| 96 |
def generate_random_id(prefix: str, length: int = 29) -> str:
|
|
|
|
| 97 |
population = string.ascii_letters + string.digits
|
| 98 |
random_part = "".join(secrets.choice(population) for _ in range(length))
|
| 99 |
return f"{prefix}{random_part}"
|
|
|
|
| 105 |
"""Lists the available models."""
|
| 106 |
return {"object": "list", "data": AVAILABLE_MODELS}
|
| 107 |
|
| 108 |
+
# (Chat, Image Generation, and Moderation endpoints are unchanged)
|
| 109 |
@app.post("/v1/chat/completions", tags=["Chat"])
|
| 110 |
async def chat_completion(request: ChatRequest):
|
| 111 |
"""Handles chat completion requests, supporting streaming and non-streaming."""
|
|
|
|
| 119 |
'user-agent': 'Mozilla/5.0',
|
| 120 |
}
|
| 121 |
|
|
|
|
| 122 |
if request.tools:
|
| 123 |
tool_prompt = f"""You have access to the following tools. To call a tool, please respond with JSON for a tool call within <tool_call></tool_call> XML tags. Respond in the format {{"name": tool name, "parameters": dictionary of argument name and its value}}. Do not use variables.
|
| 124 |
Tools: {";".join(f"<tool>{tool}</tool>" for tool in request.tools)}
|
|
|
|
| 171 |
|
| 172 |
in_tool_call = False
|
| 173 |
tool_call_buffer = ""
|
|
|
|
| 174 |
remaining_text = current_buffer.split("</tool_call>", 1)[1]
|
| 175 |
if remaining_text:
|
| 176 |
content_piece = remaining_text
|
|
|
|
| 180 |
if "<tool_call>" in content_piece:
|
| 181 |
in_tool_call = True
|
| 182 |
tool_call_buffer += content_piece.split("<tool_call>", 1)[1]
|
|
|
|
| 183 |
text_before = content_piece.split("<tool_call>", 1)[0]
|
| 184 |
if text_before:
|
|
|
|
| 185 |
delta = {"content": text_before, "tool_calls": None}
|
| 186 |
chunk = {"id": chat_id, "object": "chat.completion.chunk", "created": created, "model": model_id,
|
| 187 |
"choices": [{"index": 0, "delta": delta, "finish_reason": None}], "usage": None}
|
| 188 |
yield f"data: {json.dumps(chunk)}\n\n"
|
| 189 |
if "</tool_call>" not in tool_call_buffer:
|
| 190 |
+
continue
|
| 191 |
|
| 192 |
if not in_tool_call:
|
| 193 |
delta = {"content": content_piece}
|
|
|
|
| 204 |
except (json.JSONDecodeError, AttributeError): pass
|
| 205 |
break
|
| 206 |
|
|
|
|
| 207 |
final_usage = None
|
| 208 |
if usage_info:
|
| 209 |
final_usage = {"prompt_tokens": usage_info.get("promptTokens", 0), "completion_tokens": usage_info.get("completionTokens", 0), "total_tokens": usage_info.get("promptTokens", 0) + usage_info.get("completionTokens", 0)}
|
|
|
|
| 218 |
yield "data: [DONE]\n\n"
|
| 219 |
|
| 220 |
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
| 221 |
+
else:
|
| 222 |
full_response, usage_info = "", {}
|
| 223 |
try:
|
| 224 |
async with httpx.AsyncClient(timeout=120) as client:
|
|
|
|
| 286 |
return JSONResponse(status_code=500, content={"error": "An internal error occurred.", "details": str(e)})
|
| 287 |
return {"created": int(time.time()), "data": results}
|
| 288 |
|
| 289 |
+
|
| 290 |
+
# === FIXED OCR Endpoint ===
|
| 291 |
@app.post("/v1/ocr", response_model=OcrResponse, tags=["OCR"])
|
| 292 |
async def perform_ocr(request: OcrRequest):
|
| 293 |
"""
|
|
|
|
| 310 |
|
| 311 |
prediction = ocr_client.predict(image=handle_file(image_path), task_prompt="OCR", api_name="/process_image")
|
| 312 |
|
| 313 |
+
if not prediction or not isinstance(prediction, tuple) or len(prediction) == 0:
|
| 314 |
+
raise HTTPException(status_code=502, detail="Invalid or empty response from OCR service.")
|
| 315 |
|
| 316 |
+
raw_output = prediction[0]
|
| 317 |
+
raw_result_dict = {}
|
| 318 |
+
|
| 319 |
+
# --- START: FIX ---
|
| 320 |
+
# The Gradio client returns a JSON string, not a dict. We must parse it.
|
| 321 |
+
if isinstance(raw_output, str):
|
| 322 |
+
try:
|
| 323 |
+
raw_result_dict = json.loads(raw_output)
|
| 324 |
+
except json.JSONDecodeError:
|
| 325 |
+
raise HTTPException(status_code=502, detail="Failed to parse JSON response from OCR service.")
|
| 326 |
+
elif isinstance(raw_output, dict):
|
| 327 |
+
# If it's already a dict, use it directly
|
| 328 |
+
raw_result_dict = raw_output
|
| 329 |
+
else:
|
| 330 |
+
raise HTTPException(status_code=502, detail=f"Unexpected data type from OCR service: {type(raw_output)}")
|
| 331 |
+
# --- END: FIX ---
|
| 332 |
+
|
| 333 |
+
ocr_text = raw_result_dict.get("OCR", "")
|
| 334 |
+
# Fallback in case the OCR key is missing but there's other data
|
| 335 |
+
if not ocr_text:
|
| 336 |
+
ocr_text = str(raw_result_dict)
|
| 337 |
+
|
| 338 |
+
return OcrResponse(ocr_text=ocr_text, raw_response=raw_result_dict)
|
| 339 |
+
|
| 340 |
except Exception as e:
|
| 341 |
+
# Catch the specific HTTPException and re-raise it, otherwise wrap other exceptions
|
| 342 |
+
if isinstance(e, HTTPException):
|
| 343 |
+
raise e
|
| 344 |
raise HTTPException(status_code=500, detail=f"An error occurred during OCR processing: {str(e)}")
|
| 345 |
finally:
|
| 346 |
+
if temp_file_path and os.path.exists(temp_file_path):
|
| 347 |
os.unlink(temp_file_path)
|
| 348 |
|
| 349 |
@app.post("/v1/moderations", tags=["Moderation"])
|