Spaces:
Running
Running
| """ | |
| AmaniQuery VibeVoice Service - FastAPI wrapper for TTS | |
| Standalone HuggingFace Space for voice synthesis | |
| """ | |
| import os | |
| import io | |
| import wave | |
| import logging | |
| from typing import Optional | |
| from fastapi import FastAPI, HTTPException, Header, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from pydantic import BaseModel, Field | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(
    title="AmaniQuery VibeVoice",
    description="Text-to-Speech service for AmaniQuery using Microsoft VibeVoice",
    version="1.0.0",
)

# CORS configuration
# Allowed origins are read from the optional CORS_ALLOW_ORIGINS environment
# variable (comma-separated). When it is unset or empty, every origin is
# allowed ("*"), which matches the previous hard-coded default.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentialed CORS must not be combined with a wildcard origin, so
    # credentials are only enabled once explicit origins are configured.
    # Bearer tokens sent in an explicit Authorization header do not depend
    # on this flag.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model instance (lazy loaded by get_tts_model)
_tts_model = None
_processor = None
class SpeakRequest(BaseModel):
    """Request model for text-to-speech.

    Fields:
        text: Text to synthesise; between 1 and 5000 characters.
        voice: Name of the voice preset to use; defaults to "Wayne".
        cfg_scale: Classifier-free guidance scale, restricted to 1.0-3.0;
            defaults to 1.5.
    """

    # min_length=1 rejects an empty string at validation time instead of
    # handing it to the model.
    text: str = Field(
        ...,
        description="Text to convert to speech",
        min_length=1,
        max_length=5000,
    )
    voice: str = Field(default="Wayne", description="Voice preset to use")
    cfg_scale: float = Field(
        default=1.5,
        ge=1.0,
        le=3.0,
        description="Classifier-free guidance scale",
    )
class VoiceInfo(BaseModel):
    """Describes a single voice preset offered by the service."""

    # Preset identifier (the value clients send as ``voice``).
    name: str
    # Human-readable summary of the voice.
    description: str
# Available voice presets, built from (name, description) pairs.
VOICE_PRESETS = [
    VoiceInfo(name=preset_name, description=preset_description)
    for preset_name, preset_description in (
        ("Wayne", "Male, American English, Calm"),
        ("Angela", "Female, American English, Professional"),
        ("Aria", "Female, American English, Warm"),
        ("Davis", "Male, American English, Confident"),
    )
]
def get_tts_model():
    """Return the shared ``(model, processor)`` pair, loading it on first use.

    The device is taken from the ``VIBEVOICE_DEVICE`` environment variable
    (default ``"cpu"``). The value ``"auto"`` selects ``"cuda"`` when
    ``torch.cuda.is_available()`` is true and ``"cpu"`` otherwise; any other
    value is passed to the config unchanged.

    Returns:
        A ``(model, processor)`` tuple; the same objects on every call once
        loading has succeeded.

    Raises:
        Exception: Whatever the imports or constructors raised; it is logged
            and re-raised, and the cached globals are left untouched so the
            next call retries the load.
    """
    global _tts_model, _processor
    if _tts_model is None:
        logger.info("Loading VibeVoice model...")
        try:
            # Heavy imports are deferred so the module can be imported
            # without them being loaded up front.
            import torch
            from vibevoice.modular import (
                VibeVoiceStreamingForConditionalGenerationInference,
                VibeVoiceStreamingConfig,
            )
            from vibevoice.processor import VibeVoiceStreamingProcessor

            device = os.getenv("VIBEVOICE_DEVICE", "cpu")
            if device == "auto":
                device = "cuda" if torch.cuda.is_available() else "cpu"
            config = VibeVoiceStreamingConfig(
                model_path="microsoft/VibeVoice-Realtime-0.5B",
                device=device,
            )
            # Build both objects into locals first: if the processor fails
            # after the model succeeded, the globals must not end up holding
            # a model with a None processor.
            model = VibeVoiceStreamingForConditionalGenerationInference(config)
            processor = VibeVoiceStreamingProcessor()
        except Exception as e:
            logger.exception("Failed to load VibeVoice model: %s", e)
            raise
        _tts_model, _processor = model, processor
        logger.info("VibeVoice model loaded on %s", device)
    return _tts_model, _processor
def validate_jwt(authorization: Optional[str] = None) -> bool:
    """Validate a bearer JWT for cross-service auth (optional).

    Args:
        authorization: Raw ``Authorization`` header value, expected in the
            form ``"Bearer <token>"``. The scheme is matched
            case-insensitively.

    Returns:
        ``True`` when ``JWT_SECRET`` is unset or empty (auth disabled), or
        when the token decodes successfully with that secret using HS256.
        ``False`` for a missing or malformed header, an empty token, a token
        that fails to decode, or when the ``jwt`` package cannot be imported.
    """
    jwt_secret = os.getenv("JWT_SECRET")
    if not jwt_secret:
        # No JWT configured, allow all requests
        return True
    if not authorization:
        return False

    # Split off the scheme once instead of str.replace(), which would remove
    # every "Bearer " occurrence rather than just the prefix.
    scheme, _, token = authorization.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return False

    log = logging.getLogger(__name__)
    try:
        import jwt
    except ImportError:
        # Fail closed, but make the misconfiguration visible rather than
        # reporting it as an ordinary bad token.
        log.error("JWT_SECRET is set but the 'jwt' package is not installed")
        return False

    try:
        jwt.decode(token, jwt_secret, algorithms=["HS256"])
    except Exception as exc:
        # Any decode failure means "not authenticated"; the token itself is
        # deliberately not logged.
        log.warning("JWT validation failed: %s", exc)
        return False
    return True
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns:
        A fixed payload with ``status`` set to "healthy" and ``service`` set
        to "vibevoice". The TTS model is not touched.
    """
    return {"status": "healthy", "service": "vibevoice"}
@app.get("/api/v1/voice/voices")
async def list_voices():
    """List available voice presets.

    Returns:
        A dict with a single ``voices`` key holding one plain dict per entry
        of ``VOICE_PRESETS``.
    """
    return {
        "voices": [
            # Pydantic v2 renamed .dict() to .model_dump(); support both so
            # the handler works on either major version.
            preset.model_dump() if hasattr(preset, "model_dump") else preset.dict()
            for preset in VOICE_PRESETS
        ]
    }
def _pcm16_frames(audio) -> bytes:
    """Return *audio* as raw bytes suitable for a 16-bit PCM WAV stream."""
    dtype = getattr(audio, "dtype", None)
    if getattr(dtype, "kind", None) == "f":
        # Floating-point samples cannot be written verbatim into a stream
        # declared as 16-bit; scale them to int16 first.
        # NOTE(review): assumes float samples are normalised to [-1.0, 1.0]
        # - confirm against the model's actual output.
        import numpy as np

        audio = (np.clip(np.asarray(audio), -1.0, 1.0) * 32767.0).astype("<i2")
    # Anything else is written as-is, exactly as before.
    return audio.tobytes()


@app.post("/api/v1/voice/speak")
async def speak(
    request: SpeakRequest,
    authorization: Optional[str] = Header(None),
):
    """Convert text to speech and return it as a WAV download.

    Args:
        request: Text, voice preset and guidance scale for the synthesis.
        authorization: Value of the ``Authorization`` header, if any.

    Returns:
        A ``StreamingResponse`` with media type ``audio/wav`` containing a
        mono, 16-bit, 24000 Hz WAV file named ``speech.wav``.

    Raises:
        HTTPException: 401 when ``JWT_SECRET`` is set and the header does not
            validate; 500 when model loading or generation fails.
    """
    # Optional JWT validation
    if os.getenv("JWT_SECRET") and not validate_jwt(authorization):
        raise HTTPException(status_code=401, detail="Invalid or missing authentication")

    try:
        model, _ = get_tts_model()

        logger.info("Generating speech for: %s...", request.text[:50])

        # NOTE(review): generate() is called synchronously inside an async
        # handler, so it presumably blocks the event loop while it runs -
        # confirm whether it should be moved to a worker thread.
        audio_data = model.generate(
            text=request.text,
            voice=request.voice,
            cfg_scale=request.cfg_scale,
        )

        # Convert to WAV format
        audio_buffer = io.BytesIO()
        with wave.open(audio_buffer, "wb") as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)  # 16-bit
            wav_file.setframerate(24000)  # Sample rate
            wav_file.writeframes(_pcm16_frames(audio_data))
        # Rewind so the response streams from the start of the buffer.
        audio_buffer.seek(0)

        return StreamingResponse(
            audio_buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=speech.wav"},
        )
    except Exception as e:
        logger.exception("TTS generation failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Speech generation failed: {str(e)}",
        ) from e
async def voice_chat(
    request: SpeakRequest,
    authorization: Optional[str] = Header(None),
):
    """Produce a conversational voice reply.

    For now this is a thin alias: both arguments are forwarded to ``speak``
    and its result is returned unchanged.
    """
    # NOTE(review): no route registration for this handler is visible here,
    # and the root endpoint map does not list a path for it - confirm the
    # intended URL.
    response = await speak(request, authorization)
    return response
@app.get("/")
async def root():
    """Root endpoint with service info.

    Returns:
        A dict with the service name, its version and a map of endpoint
        names to URL paths.
    """
    return {
        "service": "AmaniQuery VibeVoice",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "speak": "/api/v1/voice/speak",
            "voices": "/api/v1/voice/voices",
        },
    }
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # using the PORT environment variable when set and 7860 otherwise.
    import uvicorn

    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)