import importlib.util
import io
import os
import spaces
# vLLM env vars appear to be leftovers from an earlier vLLM-based variant of this demo;
# they are harmless for the transformers-based inference path used below.
os.environ['VLLM_USE_V1'] = '0'
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
from argparse import ArgumentParser
import gradio as gr
import gradio.processing_utils as processing_utils
import numpy as np
import soundfile as sf
from gradio_client import utils as client_utils
import torch
# Transformers and Qwen Omni-Utils imports for local inference
from transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor
from qwen_omni_utils import process_mm_info
def _load_model_processor(args):
"""
Loads the Qwen3-Omni model and processor from Hugging Face using the transformers library.
"""
print(f"Loading model from: {args.checkpoint_path}")
# Model loading configuration
device_map = "cuda" if torch.cuda.is_available() and not args.cpu_only else "cpu"
model_kwargs = {
"dtype": "auto",
"device_map": device_map,
"trust_remote_code": True,
}
    # Use FlashAttention-2 for better performance when it is requested, CUDA is
    # available, and the flash_attn package is actually installed.
    if args.flash_attn2 and torch.cuda.is_available() and importlib.util.find_spec("flash_attn"):
        model_kwargs["attn_implementation"] = "flash_attention_2"
# Load the model and processor
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
args.checkpoint_path,
**model_kwargs
)
processor = Qwen3OmniMoeProcessor.from_pretrained(args.checkpoint_path)
print("Model and processor loaded successfully.")
return model, processor
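# Illustrative sketch (not part of the original app): _load_model_processor only needs an
# object exposing `checkpoint_path`, `cpu_only`, and `flash_attn2`, so it can be exercised
# without argparse, e.g. for a quick CPU smoke test from a notebook:
#
#     from types import SimpleNamespace
#     smoke_args = SimpleNamespace(
#         checkpoint_path="Qwen/Qwen3-Omni-30B-A3B-Instruct",
#         cpu_only=True,        # hypothetical override; the 30B MoE model needs ample RAM on CPU
#         flash_attn2=False,
#     )
#     model, processor = _load_model_processor(smoke_args)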
def _launch_demo(args, model, processor):
# Voice settings updated for the transformers model
VOICE_OPTIONS = {
"Ethan (Male)": "Ethan",
"Chelsie (Female)": "Chelsie",
"Aiden (Male)": "Aiden",
}
DEFAULT_VOICE = 'Ethan (Male)'
default_system_prompt = ''
def to_mp4(path):
"""Converts webm video files to mp4 for compatibility."""
import subprocess
if path and path.endswith(".webm"):
mp4_path = path.replace(".webm", ".mp4")
try:
subprocess.run([
"ffmpeg", "-y", "-i", path, "-c:v", "libx264",
"-preset", "ultrafast", "-pix_fmt", "yuv420p",
"-c:a", "aac", "-b:a", "128k", mp4_path
], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
return mp4_path
except (subprocess.CalledProcessError, FileNotFoundError):
print("ffmpeg conversion failed. Returning original path.")
return path
return path
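    # Behavior sketch (illustrative path only): to_mp4("/tmp/clip.webm") returns
    # "/tmp/clip.mp4" when ffmpeg is on PATH and the conversion succeeds; otherwise,
    # and for any non-.webm input, the original path is returned unchanged.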
def format_conversation_for_transformers(history: list, system_prompt: str):
"""
Formats the Gradio chat history into the conversation format required
by the Qwen3-Omni processor.
"""
conversation = []
if system_prompt:
conversation.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
# Group consecutive user messages
current_user_content = []
for item in history:
role = item['role']
content = item['content']
if role == "user":
if isinstance(content, str) and content:
current_user_content.append({"type": "text", "text": content})
elif isinstance(content, tuple) and content[0]:
file_path = content[0]
                    # get_mimetype may return None for unknown extensions; treat that as "no match".
                    mime_type = client_utils.get_mimetype(file_path) or ""
                    if mime_type.startswith("image"):
current_user_content.append({"type": "image", "image": file_path})
elif mime_type.startswith("video"):
current_user_content.append({"type": "video", "video": to_mp4(file_path)})
elif mime_type.startswith("audio"):
current_user_content.append({"type": "audio", "audio": file_path})
elif role == "assistant":
if current_user_content:
conversation.append({"role": "user", "content": current_user_content})
current_user_content = []
if isinstance(content, str) and content:
conversation.append({"role": "assistant", "content": [{"type": "text", "text": content}]})
if current_user_content:
conversation.append({"role": "user", "content": current_user_content})
return conversation
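    # Shape of the conversation returned above (values are illustrative; real paths come
    # from Gradio's upload cache):
    #
    #     [
    #         {"role": "system", "content": [{"type": "text", "text": "..."}]},
    #         {"role": "user", "content": [
    #             {"type": "image", "image": "/tmp/gradio/cat.png"},
    #             {"type": "text", "text": "What is in this picture?"},
    #         ]},
    #         {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]},
    #     ]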
@spaces.GPU
def predict(conversation, voice_choice, temperature, top_p, top_k, return_audio, enable_thinking):
"""
Runs local inference using the loaded transformers model.
"""
speaker = VOICE_OPTIONS[voice_choice]
use_audio_in_video = True # Consistently process audio from video files
        # Build the chat-template prompt string and collect the raw audio/image/video
        # inputs referenced by the conversation.
        text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
        audios, images, videos = process_mm_info(conversation, use_audio_in_video=use_audio_in_video)
inputs = processor(text=text,
audio=audios,
images=images,
videos=videos,
return_tensors="pt",
padding=True,
use_audio_in_video=use_audio_in_video)
inputs = inputs.to(model.device).to(model.dtype)
gen_kwargs = {
"speaker": speaker,
"thinker_return_dict_in_generate": True,
"use_audio_in_video": use_audio_in_video,
"return_audio": return_audio,
"temperature": float(temperature),
"top_p": float(top_p),
"top_k": int(top_k),
"max_new_tokens": 8192,
}
        # With thinker_return_dict_in_generate=True the first return value exposes
        # `.sequences`; the second is the talker's audio waveform (expected to be None
        # when no audio is generated).
        text_ids, audio_tensor = model.generate(**inputs, **gen_kwargs)
response_text = processor.batch_decode(
text_ids.sequences[:, inputs["input_ids"].shape[1]:],
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)[0]
yield {"type": "text", "data": response_text}
if audio_tensor is not None and return_audio:
audio_np = audio_tensor.reshape(-1).detach().cpu().numpy()
with io.BytesIO() as wav_io:
sf.write(wav_io, audio_np, samplerate=24000, format="WAV")
wav_bytes = wav_io.getvalue()
audio_path = processing_utils.save_bytes_to_cache(
wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE
)
yield {"type": "audio", "data": audio_path}
@spaces.GPU
def chat_predict(text, audio, image, video, history, system_prompt, voice_choice, temperature, top_p, top_k,
return_audio=False, enable_thinking=False):
if audio:
history.append({"role": "user", "content": (audio,)})
if image:
history.append({"role": "user", "content": (image,)})
if video:
history.append({"role": "user", "content": (video,)})
if text:
history.append({"role": "user", "content": text})
yield gr.Textbox(value=None), gr.Audio(value=None), gr.Image(value=None), gr.Video(value=None), history
conversation = format_conversation_for_transformers(history, system_prompt)
history.append({"role": "assistant", "content": ""})
final_text = ""
final_audio_path = None
for chunk in predict(conversation, voice_choice, temperature, top_p, top_k, return_audio, enable_thinking):
if chunk["type"] == "text":
final_text = chunk["data"]
history[-1]["content"] = final_text
yield gr.skip(), gr.skip(), gr.skip(), gr.skip(), history
elif chunk["type"] == "audio":
final_audio_path = chunk["data"]
if final_audio_path:
history.append({"role": "assistant", "content": gr.Audio(final_audio_path, autoplay=True)})
yield gr.skip(), gr.skip(), gr.skip(), gr.skip(), history
with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]),
css=".gradio-container {max-width: none !important;}") as demo:
gr.Markdown("# Qwen3-Omni Demo (Local Transformers on HF Spaces)")
gr.Markdown(
"**Instructions**: Interact with the locally running model through text, audio, images, or video.")
with gr.Row(equal_height=False):
with gr.Column(scale=1):
gr.Markdown("### ⚙️ Parameters")
system_prompt_textbox = gr.Textbox(label="System Prompt", value=default_system_prompt, lines=4,
max_lines=8)
voice_choice = gr.Dropdown(label="Voice Choice", choices=list(VOICE_OPTIONS.keys()), value=DEFAULT_VOICE,
visible=True)
return_audio = gr.Checkbox(
label="Return Audio",
value=True,
interactive=True,
)
enable_thinking = gr.Checkbox(
label="Enable Thinking",
value=False,
interactive=True,
info="Note: Requires loading the 'Thinking' model variant."
)
temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, value=0.6, step=0.1)
top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.95, step=0.05)
top_k = gr.Slider(label="Top K", minimum=1, maximum=100, value=20, step=1)
with gr.Column(scale=3):
chatbot = gr.Chatbot(label="Chat History", height=650,
layout="panel", bubble_full_width=False,
render=False,
type="messages")
chatbot.render()
with gr.Accordion("📎 Click to upload multimodal files", open=False):
with gr.Row():
audio_input = gr.Audio(sources=["upload", 'microphone'], type="filepath", label="Audio")
image_input = gr.Image(sources=["upload", 'webcam'], type="filepath", label="Image")
video_input = gr.Video(sources=["upload", 'webcam'], label="Video")
with gr.Row():
text_input = gr.Textbox(show_label=False,
placeholder="Enter text or upload files and press Submit...",
scale=7)
submit_btn = gr.Button("Submit", variant="primary", scale=1)
clear_btn = gr.Button("Clear", scale=1)
def clear_history():
return [], None, None, None, None
submit_event = gr.on(
triggers=[submit_btn.click, text_input.submit],
fn=chat_predict,
inputs=[text_input, audio_input, image_input, video_input, chatbot, system_prompt_textbox,
voice_choice, temperature, top_p, top_k, return_audio, enable_thinking],
outputs=[text_input, audio_input, image_input, video_input, chatbot]
)
clear_btn.click(fn=clear_history,
outputs=[chatbot, text_input, audio_input, image_input, video_input])
demo.queue().launch(share=args.share,
inbrowser=args.inbrowser,
server_port=args.server_port,
server_name=args.server_name)
DEFAULT_CKPT_PATH = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
def _get_args():
parser = ArgumentParser()
parser.add_argument('-c',
'--checkpoint-path',
type=str,
default=DEFAULT_CKPT_PATH,
                        help='Hugging Face model checkpoint name or path, defaults to %(default)r')
parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
    parser.add_argument('--flash-attn2',
                        action='store_true',
                        default=True,  # enabled by default for Spaces; the loader still checks availability
                        help='Enable flash_attention_2 when loading the model.')
parser.add_argument('--share',
action='store_true',
default=False,
help='Create a publicly shareable link for the interface.')
parser.add_argument('--inbrowser',
action='store_true',
default=False,
help='Automatically launch the interface in a new tab on the default browser.')
parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
parser.add_argument('--server-name', type=str, default=None, help='Demo server name.') # Set to None for Spaces
args = parser.parse_args([]) # Use empty list for args when running in Spaces
return args
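# Sketch (assumption, not part of the original app): to honor real CLI flags when running
# locally while keeping the Spaces-safe parse_args([]), the call could be made conditional
# on the SPACE_ID environment variable that Hugging Face Spaces sets:
#
#     import sys
#     argv = [] if os.environ.get("SPACE_ID") else sys.argv[1:]
#     args = parser.parse_args(argv)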
if __name__ == "__main__":
args = _get_args()
model, processor = _load_model_processor(args)
_launch_demo(args, model, processor)
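# Example local invocations (only meaningful if _get_args is changed to read real CLI
# flags, as sketched above):
#
#     python app.py --cpu-only
#     python app.py -c Qwen/Qwen3-Omni-30B-A3B-Instruct --share --server-port 7860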