Spaces:
Paused
Paused
| import asyncio | |
| import contextlib | |
| import typing | |
| from typing import Callable, Dict, Union | |
| import aiohttp | |
| import aiohttp.client_exceptions | |
| import aiohttp.http_exceptions | |
| import httpx | |
| from aiohttp.client import ClientResponse, ClientSession | |
| from litellm._logging import verbose_logger | |
# Maps aiohttp exception types to the httpx exception that
# ``map_aiohttp_exceptions`` raises in their place. Every entry whose key
# matches the raised exception is considered: the first match picks the
# target, and a later match only replaces it when its target is a subclass of
# the current one. Entry order therefore decides between matches whose
# targets are unrelated, so keep the most specific aiohttp types first.
AIOHTTP_EXC_MAP: Dict = {
    # Timeout related exceptions
    aiohttp.ServerTimeoutError: httpx.TimeoutException,
    aiohttp.ConnectionTimeoutError: httpx.ConnectTimeout,
    aiohttp.SocketTimeoutError: httpx.ReadTimeout,
    # Proxy related exceptions
    aiohttp.ClientProxyConnectionError: httpx.ProxyError,
    # SSL related exceptions
    aiohttp.ClientConnectorCertificateError: httpx.ProtocolError,
    aiohttp.ClientSSLError: httpx.ProtocolError,
    aiohttp.ServerFingerprintMismatch: httpx.ProtocolError,
    # Network related exceptions
    aiohttp.ClientConnectorError: httpx.ConnectError,
    aiohttp.ClientOSError: httpx.ConnectError,
    # Payload errors. aiohttp.ClientPayloadError is the same class object as
    # aiohttp.client_exceptions.ClientPayloadError, so one entry covers both.
    aiohttp.ClientPayloadError: httpx.ReadError,
    # Connection disconnection exceptions
    aiohttp.ServerDisconnectedError: httpx.ReadError,
    # Response related exceptions
    aiohttp.ClientConnectionError: httpx.NetworkError,
    aiohttp.ContentTypeError: httpx.ReadError,
    aiohttp.TooManyRedirects: httpx.TooManyRedirects,
    # URL related exceptions
    aiohttp.InvalidURL: httpx.InvalidURL,
    # Base exceptions
    aiohttp.ClientError: httpx.RequestError,
}
@contextlib.contextmanager
def map_aiohttp_exceptions() -> typing.Iterator[None]:
    """Translate aiohttp exceptions raised in the ``with`` body into httpx ones.

    The raised exception is compared against every entry of
    ``AIOHTTP_EXC_MAP``. The first matching entry selects the target type;
    a later match replaces it only when its target is a subclass of the
    current one, so the most specific httpx exception wins. The httpx
    exception is raised with ``str(exc)`` as its message and chained to the
    original (``raise ... from exc``). Exceptions matching no entry are
    re-raised unchanged.

    Without the ``contextmanager`` decorator this would be a bare generator
    function, and ``with map_aiohttp_exceptions():`` would fail immediately.
    """
    try:
        yield
    except Exception as exc:
        mapped_exc = None
        for from_exc, to_exc in AIOHTTP_EXC_MAP.items():
            if not isinstance(exc, from_exc):  # type: ignore
                continue
            # Prefer a more specific httpx target over one already chosen.
            if mapped_exc is None or issubclass(to_exc, mapped_exc):
                mapped_exc = to_exc

        if mapped_exc is None:  # pragma: no cover
            raise

        message = str(exc)
        raise mapped_exc(message) from exc
class AiohttpResponseStream(httpx.AsyncByteStream):
    """httpx byte stream that reads the body of an aiohttp ``ClientResponse``.

    The body is yielded in chunks of at most ``CHUNK_SIZE`` bytes.
    """

    CHUNK_SIZE = 1024 * 16

    def __init__(self, aiohttp_response: ClientResponse) -> None:
        self._aiohttp_response = aiohttp_response

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        """Yield the response body chunk by chunk.

        Incomplete transfers (``ClientPayloadError`` or
        ``TransferEncodingError``) are deliberately tolerated: they are
        logged at debug level and the stream simply ends. The caller receives
        whatever was already yielded, which may be nothing, and no exception
        is raised. Any other exception is re-raised through
        ``map_aiohttp_exceptions``.
        """
        try:
            async for chunk in self._aiohttp_response.content.iter_chunked(
                self.CHUNK_SIZE
            ):
                yield chunk
        except aiohttp.ClientPayloadError as e:
            # Same class as aiohttp.client_exceptions.ClientPayloadError.
            # Truncated body: end the stream with what was received so far.
            # NOTE: this does not check whether any data was yielded first.
            verbose_logger.debug(f"Transfer incomplete, but continuing: {e}")
            return
        except aiohttp.http_exceptions.TransferEncodingError as e:
            # Same best-effort treatment for transfer-encoding failures.
            verbose_logger.debug(f"Transfer encoding error, but continuing: {e}")
            return
        except Exception:
            # For other exceptions, use the normal mapping
            with map_aiohttp_exceptions():
                raise

    async def aclose(self) -> None:
        """Exit the aiohttp response's async context, mapping aiohttp errors."""
        with map_aiohttp_exceptions():
            await self._aiohttp_response.__aexit__(None, None, None)
class AiohttpTransport(httpx.AsyncBaseTransport):
    """
    Minimal httpx async transport that holds an aiohttp session.

    ``client`` may be a ready ``ClientSession`` or a zero-argument callable
    producing one; it is stored unchanged on ``self.client``.
    """

    def __init__(
        self, client: Union[ClientSession, Callable[[], ClientSession]]
    ) -> None:
        self.client = client

    async def aclose(self) -> None:
        """Close ``self.client`` when it is a ``ClientSession``.

        Anything else stored there (e.g. a factory) is left untouched.
        """
        session = self.client
        if not isinstance(session, ClientSession):
            return
        await session.close()
class LiteLLMAiohttpTransport(AiohttpTransport):
    """
    LiteLLM wrapper around AiohttpTransport to handle %-encodings in URLs
    and event loop lifecycle issues in CI/CD environments

    Credit to: https://github.com/karpetrosyan/httpx-aiohttp for this implementation
    """

    def __init__(self, client: Union[ClientSession, Callable[[], ClientSession]]):
        super().__init__(client=client)
        # Keep the factory (if one was given) so sessions can be rebuilt when
        # the current one is closed or bound to another event loop.
        self._client_factory: typing.Optional[Callable[[], ClientSession]] = (
            client if callable(client) else None
        )

    def _create_client_session(self) -> ClientSession:
        """Build a new session from the stored factory, or a default one."""
        factory = getattr(self, "_client_factory", None)
        if callable(factory):
            return factory()
        return ClientSession()

    def _get_valid_client_session(self) -> ClientSession:
        """
        Return a ClientSession usable on the current event loop.

        A new session is created (via the stored factory, otherwise a default
        ``ClientSession()``) and stored in ``self.client`` when:

        - ``self.client`` is not a ``ClientSession`` (e.g. still the factory);
        - the session has been closed;
        - the session is bound to a different or closed event loop, or its
          loop / the running loop cannot be determined (common in CI/CD
          environments where event loops are created and torn down).

        The replaced session is not closed here: it may belong to another
        loop, so awaiting its ``close()`` from this loop is not safe.
        """
        if not isinstance(self.client, ClientSession):
            self.client = self._create_client_session()
            return self.client

        try:
            session_loop = getattr(self.client, "_loop", None)
            current_loop = asyncio.get_running_loop()
            needs_new_session = (
                self.client.closed
                or session_loop is None
                or session_loop != current_loop
                or session_loop.is_closed()
            )
        except (RuntimeError, AttributeError):
            # No running loop, or the session lacks the expected attributes:
            # treat it as unusable and rebuild.
            needs_new_session = True

        if needs_new_session:
            verbose_logger.debug(
                "Creating a new aiohttp ClientSession for the current event loop"
            )
            self.client = self._create_client_session()
        return self.client

    async def handle_async_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        """
        Send ``request`` through aiohttp and wrap the reply as an httpx.Response.

        - The URL is given to yarl with ``encoded=True`` so existing
          %-escapes are kept as-is instead of being re-encoded.
        - Redirects are not followed and the body is not decompressed here.
        - httpx timeouts map onto ``aiohttp.ClientTimeout`` as
          ``connect`` -> ``sock_connect``, ``read`` -> ``sock_read``,
          ``pool`` -> ``connect``.
        - The ``sni_hostname`` extension, if set, becomes ``server_hostname``.
        - The body is streamed lazily through ``AiohttpResponseStream``.
        """
        from aiohttp import ClientTimeout
        from yarl import URL as YarlURL

        timeout = request.extensions.get("timeout", {})
        sni_hostname = request.extensions.get("sni_hostname")

        # Use helper to ensure we have a valid session for the current event loop
        client_session = self._get_valid_client_session()

        with map_aiohttp_exceptions():
            try:
                data = request.content
            except httpx.RequestNotRead:
                # Unread streaming body: pass the stream itself to aiohttp.
                data = request.stream  # type: ignore
                request.headers.pop("transfer-encoding", None)  # handled by aiohttp

            response = await client_session.request(
                method=request.method,
                url=YarlURL(str(request.url), encoded=True),
                headers=request.headers,
                data=data,
                allow_redirects=False,
                auto_decompress=False,
                timeout=ClientTimeout(
                    sock_connect=timeout.get("connect"),
                    sock_read=timeout.get("read"),
                    connect=timeout.get("pool"),
                ),
                server_hostname=sni_hostname,
            ).__aenter__()

        return httpx.Response(
            status_code=response.status,
            headers=response.headers,
            content=AiohttpResponseStream(response),
            request=request,
        )