rosassebastian2003 committed
Commit f2527c6
1 Parent(s): 8641305
Files changed (1)
  1. handler.py +41 -24
handler.py CHANGED
@@ -8,92 +8,107 @@ import os
 import tempfile
 import numpy as np
 
+# Model name (used as a fallback when 'path' is not provided)
+MODEL_NAME = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
+
 class EndpointHandler():
     def __init__(self, path=""):
+
+        # 1. Critical settings for loading the MoE model and enabling speech output
         model_kwargs = {
-            "device_map": "auto",
+            "device_map": "auto",  # distributes the weights across available GPUs
             "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else None,
-            "enable_audio_output": True
+            "enable_audio_output": True  # essential for loading the Talker (speech generator) component
         }
 
+        # 2. Load the generic text-generation pipeline (the wrapper for multimodal LLMs)
         self.pipeline = pipeline(
             task="text-generation",
-            model=path,
-            **model_kwargs
+            model=path or MODEL_NAME,
+            **model_kwargs  # inject the Qwen3-specific parameters
         )
 
+        # 3. System prompt required by Qwen3-Omni to generate natural-sounding speech
         self.system_prompt = (
            "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
            "capable of perceiving auditory and visual inputs, as well as generating text and speech."
         )
 
-        self.sampling_rate = self.pipeline.model.config.sampling_rate
+        # 4. Model sampling rate (needed for audio serialization in __call__)
+        self.sampling_rate = getattr(self.pipeline.model.config, 'sampling_rate', 24000)
 
     def _handle_audio_input(self, data: Dict[str, Any]) -> str:
+        """Decode the Base64 audio input and save it temporarily as a WAV file."""
         audio_data_base64 = data.get("audio_data")
         if not audio_data_base64:
             return None
 
         temp_file_path = None
         try:
             audio_bytes = base64.b64decode(audio_data_base64)
-
+            # Save to a temporary file so the pipeline can process it
             temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
             temp_file.write(audio_bytes)
             temp_file.close()
             temp_file_path = temp_file.name
-
             return temp_file_path
         except Exception as e:
-
             if temp_file_path and os.path.exists(temp_file_path):
                 os.remove(temp_file_path)
             raise ValueError(f"Error decoding and saving the Base64 audio: {e}")
 
     def _handle_audio_output(self, generated_audio: torch.Tensor, sampling_rate: int) -> str:
+        """Convert the output audio tensor to a WAV buffer and encode it as Base64."""
         audio_array = generated_audio.cpu().numpy().squeeze()
         if audio_array.dtype != np.float32:
             audio_array = audio_array.astype(np.float32)
 
-        encoded_audio = None
         with io.BytesIO() as buffer:
+            # Write the array as WAV
             wavfile.write(buffer, rate=sampling_rate, data=audio_array)
             buffer.seek(0)
-
-            encoded_audio = base64.b64encode(buffer.read()).decode('utf-8')
-
-        return encoded_audio
+            # Encode as Base64 for the JSON response
+            encoded_audio = base64.b64encode(buffer.read()).decode('utf-8')
+            return encoded_audio
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         prompt = data.get("inputs")
         if not prompt:
             raise ValueError("The 'inputs' field (text prompt) is required.")
 
         generation_kwargs = data.get("parameters", {})
         audio_file_path = None
 
         try:
+            # 1. Audio I/O handling (Base64 -> temporary file)
             audio_file_path = self._handle_audio_input(data)
 
+            # 2. The pipeline expects a list of multimodal inputs (text or audio)
             inputs_list = [prompt]
             if audio_file_path:
                 inputs_list.append(audio_file_path)
 
+            # 3. Generation configuration
             generation_kwargs.update({
-                "system_prompt": self.system_prompt,
-                "return_audio": True,
+                "system_prompt": self.system_prompt,  # required for speech quality
+                "return_audio": True,  # ask the pipeline to return the audio tensor
                 "max_new_tokens": generation_kwargs.get("max_new_tokens", 512),
             })
 
+            # 4. Run the pipeline
             raw_output = self.pipeline(inputs_list, **generation_kwargs)
 
+            # The pipeline returns a list of dicts; take the first result
             response = raw_output[0]
 
             final_response = {
                 "generated_text": response.get("generated_text"),
                 "audio_output": None
             }
 
+            # 5. Post-processing (tensor -> Base64 WAV)
             if "audio_array" in response:
                 encoded_audio = self._handle_audio_output(response["audio_array"], self.sampling_rate)
                 final_response["audio_output"] = encoded_audio
@@ -101,8 +116,10 @@
             return [final_response]
 
         except Exception as e:
+            # Error handling
             return [{"error": str(e)}]
 
         finally:
+            # 6. Temporary-file cleanup (critical housekeeping)
             if audio_file_path and os.path.exists(audio_file_path):
                 os.remove(audio_file_path)
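
For reference, a minimal client-side sketch of the request/response contract this handler implements. The payload keys ("inputs", "audio_data", "parameters") and response keys ("generated_text", "audio_output", "error") come straight from handler.py above; the endpoint URL, token, and file names are placeholders you must replace with your own deployment's values.

import base64
import json
import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder token

# Optional audio input: Base64-encode a local WAV file.
with open("question.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": "Please answer the question in the attached audio.",
    "audio_data": audio_b64,                # omit this key for text-only requests
    "parameters": {"max_new_tokens": 256},  # merged into generation_kwargs by the handler
}

resp = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"},
    data=json.dumps(payload),
)
result = resp.json()[0]  # the handler returns a single-element list

if "error" in result:
    raise RuntimeError(result["error"])
print(result["generated_text"])

# audio_output, when present, is the Base64 WAV produced by _handle_audio_output.
if result.get("audio_output"):
    with open("answer.wav", "wb") as f:
        f.write(base64.b64decode(result["audio_output"]))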
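
The same request dict can also be fed to the handler class directly for a local smoke test, assuming the repository is checked out and there is enough GPU memory to load the Qwen3-Omni weights; the prompt here is illustrative.

from handler import EndpointHandler

# An empty path falls back to MODEL_NAME inside __init__.
handler = EndpointHandler(path="")

output = handler({
    "inputs": "Introduce yourself in one sentence.",
    "parameters": {"max_new_tokens": 64},
})[0]

print(output.get("generated_text") or output.get("error"))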