from typing import Dict, List, Any

import soundfile as sf
from transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor
from qwen_omni_utils import process_mm_info


class EndpointHandler:
    def __init__(self, path: str = "./"):
        # Load the Qwen3-Omni MoE model and its processor from the repository
        # path the endpoint was deployed with.
        self.model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
            path,
            dtype="auto",
            device_map="auto",
        )
        self.processor = Qwen3OmniMoeProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        messages = data.get("messages", [])
        use_audio_in_video = data.get("use_audio_in_video", True)
        speaker = data.get("speaker", "Ethan")

        # Render the chat template and collect any audio, image, and video
        # inputs referenced in the messages.
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        audios, images, videos = process_mm_info(messages, use_audio_in_video=use_audio_in_video)
        inputs = self.processor(
            text=text,
            audio=audios,
            images=images,
            videos=videos,
            return_tensors="pt",
            padding=True,
            use_audio_in_video=use_audio_in_video,
        )
        inputs = inputs.to(self.model.device).to(self.model.dtype)

        # Generate text token ids and, when the talker is enabled, a speech waveform.
        text_ids, audio = self.model.generate(
            **inputs,
            speaker=speaker,
            thinker_return_dict_in_generate=True,
            use_audio_in_video=use_audio_in_video,
        )
        # Decode only the newly generated tokens, skipping the prompt.
        text_output = self.processor.batch_decode(
            text_ids.sequences[:, inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        result = {"generated_text": text_output}
        if audio is not None:
            # The talker returns a 24 kHz waveform; write it to disk and report
            # the path alongside the generated text.
            sf.write(
                "output.wav",
                audio.reshape(-1).detach().cpu().numpy(),
                samplerate=24000,
            )
            result["audio_path"] = "output.wav"
        return [result]
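

if __name__ == "__main__":
    # Minimal local smoke test, not part of the endpoint contract. It assumes the
    # model weights are present in the current directory and that the payload
    # follows the Qwen multimodal chat convention (messages whose "content" is a
    # list of typed parts); adjust it to match your actual request schema.
    handler = EndpointHandler(path="./")
    sample_request = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Introduce yourself in one sentence."},
                ],
            }
        ],
        "use_audio_in_video": False,
        "speaker": "Ethan",
    }
    for output in handler(sample_request):
        print(output["generated_text"])
        if "audio_path" in output:
            print("audio written to", output["audio_path"])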