"""Gradio demo UI for the Virtual Outfit Try-On API hosted on RapidAPI.

The user supplies RapidAPI credentials, a person image and a garment image.
The app submits a try-on task through ``TryOnAPIClient``, waits for it to
complete and displays the returned (base64-encoded) result image.
"""

import base64
import binascii
import logging
import os
from io import BytesIO

import gradio as gr
from PIL import Image

from tryon_api_client import TryOnAPIClient, TryOnAPIError

logger = logging.getLogger(__name__)

# File extensions accepted for uploaded images (compared case-insensitively).
VALID_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

# Example assets shipped with the demo.
# NOTE(review): every path starts with "images /" (space before the slash).
# Kept verbatim from the original; confirm it matches the on-disk directory.
PERSON_EXAMPLES = [
    "images /models/try-on-model.png",
    "images /models/try-on-model-2.png",
    "images /models/try-on-model-3.png",
    "images /models/try-on-model-4.png",
]
GARMENT_EXAMPLES = [
    "images /garment/try-on-garment.png",
    "images /garment/try-on-garment-2.png",
    "images /garment/try-on-garment-3.jpg",
    "images /garment/try-on-garment-4.jpg",
]
RESULT_EXAMPLE = "images /results/after-try-on.jpg"

API_PAGE_URL = "https://rapidapi.com/aiproviderlabs/api/virtual-outfit-try-on-api"

API_LINK_HTML = f"""
<a href="{API_PAGE_URL}" target="_blank" rel="noopener noreferrer">Check out the API @ RapidAPI.com</a>
"""

CREDENTIALS_HELP_MD = f"""
1. Go to **[Virtual Outfit Try On API on RapidAPI]({API_PAGE_URL})**.
2. Subscribe to a plan (there is a free tier).
3. Go to the **Endpoints** tab.
4. On the right side, look for **Header Parameters**:
    - `X-RapidAPI-Key` is your generic RapidAPI Key.
    - `X-RapidAPI-Host` is `virtual-outfit-try-on-api.p.rapidapi.com` (or similar).
5. Copy these values into the fields below.
"""


def _normalize_host(host):
    """Return *host* without surrounding whitespace, an http(s):// scheme or trailing slashes.

    Users often paste the full URL; the RapidAPI host header and the base URL
    built from it both need the bare hostname.
    """
    host = (host or "").strip()
    for scheme in ("https://", "http://"):
        if host.lower().startswith(scheme):
            host = host[len(scheme):]
            break
    return host.rstrip("/")


def _validate_image_path(path, label):
    """Raise ``gr.Error`` unless *path* is a string with an accepted image extension.

    *label* ("Person" / "Garment") is interpolated into the error message.
    """
    if not isinstance(path, str):
        raise gr.Error("⚠️ Invalid image file. Please upload valid image files.")
    if not path.lower().endswith(VALID_IMAGE_EXTENSIONS):
        raise gr.Error(f"⚠️ {label} image must be PNG, JPEG, or JPG format.")


def _create_task(client, person_img, garment_img, garment_type):
    """Submit a try-on job via *client* and return its task id.

    Every failure is converted into a user-facing ``gr.Error``.
    """
    try:
        task_data = client.create_task(person_img, garment_img, garment_type)
    except TryOnAPIError as e:
        raise gr.Error(f"❌ API Error {e.status_code}: {e.message}") from e
    except ConnectionError as e:
        raise gr.Error(
            "🔌 Cannot connect to the Try-On service. Please check if the API server is running."
        ) from e
    except TimeoutError as e:
        raise gr.Error("⏱️ Connection timeout. The server is taking too long to respond.") from e
    except Exception as e:
        raise gr.Error(f"❌ Failed to create try-on task: {e}") from e

    # A payload without .get() (e.g. None) is reported as "no task ID" rather
    # than escaping as an AttributeError / "Unexpected error".
    try:
        task_id = task_data.get("task_id")
    except AttributeError:
        task_id = None
    if not task_id:
        raise gr.Error("❌ No task ID returned from API. Please try again.")
    return task_id


def _wait_for_result(client, task_id):
    """Block until *task_id* finishes and return the client's result payload."""
    try:
        return client.wait_for_completion(task_id)
    except TimeoutError as e:
        raise gr.Error(
            "⏱️ Try-on process is taking too long. Please try again with smaller images."
        ) from e
    except Exception as e:
        raise gr.Error(f"❌ Failed while processing try-on: {e}") from e


def _extract_result_image(result):
    """Decode the ``image_base64`` field of *result* into a PIL image."""
    try:
        image_base64 = result.get("image_base64")
    except AttributeError:
        image_base64 = None
    if not image_base64:
        raise gr.Error("❌ Try-on completed but no image was generated. Please try again.")
    try:
        return base64_to_image(image_base64)
    except Exception as e:
        raise gr.Error(f"❌ Failed to decode result image: {e}") from e


def process_tryon(rapid_key, rapid_host, person_img, garment_img, garment_type):
    """Run one virtual try-on request and return the resulting PIL image.

    Args:
        rapid_key: RapidAPI key from the UI. If blank, the server-side
            ``X_RAPIDAPI_KEY`` environment variable is used instead.
        rapid_host: RapidAPI host, e.g. ``virtual-outfit-try-on-api.p.rapidapi.com``.
            A leading scheme or trailing slash is tolerated.
        person_img: File path of the person image (Gradio ``type="filepath"``).
        garment_img: File path of the garment image.
        garment_type: One of the Radio choices ("top", "bottom", "fullset").

    Returns:
        The decoded result image.

    Raises:
        gr.Error: For missing/invalid inputs and any API or decoding failure.
    """
    # The env fallback replaces prefilling the password box. A component's
    # initial value is shipped to every visitor's browser, which would leak
    # the server key.
    rapid_key = (rapid_key or "").strip() or os.environ.get("X_RAPIDAPI_KEY", "").strip()
    rapid_host = _normalize_host(rapid_host)
    # Never log the key itself: it is a credential.
    logger.debug("Try-on request: host=%r garment_type=%r", rapid_host, garment_type)

    if not rapid_key:
        raise gr.Error("⚠️ Please enter your X-RapidAPI-Key.")
    if not rapid_host:
        raise gr.Error("⚠️ Please enter your X-RapidAPI-Host.")

    if not person_img or not garment_img:
        raise gr.Error("⚠️ Please upload both Person and Garment images.")
    _validate_image_path(person_img, "Person")
    _validate_image_path(garment_img, "Garment")

    client = TryOnAPIClient(
        base_url=f"https://{rapid_host}",
        rapid_api_key=rapid_key,
        rapid_api_host=rapid_host,
    )

    try:
        task_id = _create_task(client, person_img, garment_img, garment_type)
        result = _wait_for_result(client, task_id)
        return _extract_result_image(result)
    except gr.Error:
        # Already user-facing; re-raise as-is.
        raise
    except Exception as e:
        raise gr.Error(f"❌ Unexpected error occurred: {e}") from e


def base64_to_image(base64_string):
    """Decode a base64-encoded image into a fully loaded ``PIL.Image.Image``.

    Accepts raw base64 or a ``data:<mime>;base64,<payload>`` data URI.

    Raises:
        ValueError: If the base64 is malformed or the bytes are not a
            readable image.
    """
    payload = base64_string
    if isinstance(payload, str) and payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    try:
        image_data = base64.b64decode(payload)
        image = Image.open(BytesIO(image_data))
        # Image.open is lazy. Force decoding here so corrupt data is reported
        # now, not later during Gradio serialization.
        image.load()
    except binascii.Error as e:
        raise ValueError("Invalid base64 string format") from e
    except Exception as e:
        raise ValueError(f"Cannot decode image: {e}") from e
    return image


# UI Layout
with gr.Blocks(title="Virtual Try-On") as app:
    gr.Markdown("# 👕 Virtual Try-On Demo")
    gr.Markdown("Upload a person image and a garment image to see the magic")
    gr.HTML(API_LINK_HTML)

    with gr.Row():
        gr.Markdown("### Before - After Example")
    with gr.Row():
        gr.Image(value=PERSON_EXAMPLES[0], label="Model", interactive=False, height=300)
        gr.Image(value=GARMENT_EXAMPLES[0], label="Garment - Full Set", interactive=False, height=300)
        gr.Image(value=RESULT_EXAMPLE, label="Result", interactive=False, height=300)

    with gr.Row():
        gr.Markdown("### Try-On Example")
    with gr.Row():
        with gr.Column():
            with gr.Accordion("ℹ️ How to get your API Credentials", open=False):
                gr.Markdown(CREDENTIALS_HELP_MD)
            with gr.Group():
                # Deliberately NOT prefilled from X_RAPIDAPI_KEY (see process_tryon).
                rapid_key_input = gr.Textbox(
                    label="X-RapidAPI-Key",
                    type="password",
                    placeholder="Enter RapidAPI Key",
                )
                rapid_host_input = gr.Textbox(
                    label="X-RapidAPI-Host",
                    value=os.environ.get("X_RAPIDAPI_HOST", ""),
                    placeholder="Enter RapidAPI Host",
                )
            person_input = gr.Image(
                type="filepath",
                label="Person Image (PNG, JPEG, JPG only)",
                height=400,
            )
            gr.Examples(examples=PERSON_EXAMPLES, inputs=person_input)
            garment_input = gr.Image(
                type="filepath",
                label="Garment Image (PNG, JPEG, JPG only)",
                height=400,
            )
            gr.Examples(examples=GARMENT_EXAMPLES, inputs=garment_input)
            garment_type_input = gr.Radio(
                choices=["top", "bottom", "fullset"],
                value="top",
                label="Garment Type",
                info="Select the type of clothing you are trying on.",
            )
            run_btn = gr.Button("🚀 Start Try-On", variant="primary")
        with gr.Column():
            result_output = gr.Image(label="Try-On Result", height=400)

    run_btn.click(
        fn=process_tryon,
        inputs=[rapid_key_input, rapid_host_input, person_input, garment_input, garment_type_input],
        outputs=result_output,
    )

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)