Spaces:
Paused
Paused
| import asyncio | |
| import httpx | |
| import json | |
| import pytest | |
| import sys | |
| from typing import Any, Dict, List | |
| from unittest.mock import MagicMock, Mock, patch, ANY | |
| import os | |
| sys.path.insert( | |
| 0, os.path.abspath("../..") | |
| ) # Adds the parent directory to the system path | |
| import litellm | |
| from litellm.exceptions import BadRequestError | |
| from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler | |
| from litellm.utils import CustomStreamWrapper | |
| from base_llm_unit_tests import BaseLLMChatTest, BaseAnthropicChatTest | |
# Detect whether the optional Databricks SDK is importable so that tests which
# exercise the SDK-based auth fallback can be skipped when it is not installed.
try:
    import databricks.sdk

    databricks_sdk_installed = True
except ImportError:
    databricks_sdk_installed = False
def mock_chat_response() -> Dict[str, Any]:
    """Return a canned, non-streaming Databricks /chat/completions payload."""
    assistant_message = {
        "role": "assistant",
        "content": "Hello! I'm an AI assistant. I'm doing well. How can I help?",
        "function_call": None,
        "tool_calls": None,
    }
    token_usage = {
        "prompt_tokens": 230,
        "completion_tokens": 38,
        "completion_tokens_details": None,
        "total_tokens": 268,
        "prompt_tokens_details": None,
    }
    return {
        "id": "chatcmpl_3f78f09a-489c-4b8d-a587-f162c7497891",
        "object": "chat.completion",
        "created": 1726285449,
        "model": "dbrx-instruct-071224",
        "choices": [
            {
                "index": 0,
                "message": assistant_message,
                "finish_reason": "stop",
            }
        ],
        "usage": token_usage,
        "system_fingerprint": None,
    }
def mock_chat_streaming_response_chunks() -> List[str]:
    """Return three serialized streaming chunks spelling "Hello world!".

    Only the delta payload and the finish reason vary between chunks; all
    other fields are identical, so the chunks are generated from a spec list.
    """
    delta_specs = [
        ({"role": "assistant", "content": "Hello"}, None),
        ({"content": " world"}, None),
        ({"content": "!"}, "stop"),
    ]
    serialized_chunks: List[str] = []
    for delta, finish_reason in delta_specs:
        serialized_chunks.append(
            json.dumps(
                {
                    "id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
                    "object": "chat.completion.chunk",
                    "created": 1726469651,
                    "model": "dbrx-instruct-071224",
                    "choices": [
                        {
                            "index": 0,
                            "delta": delta,
                            "finish_reason": finish_reason,
                            "logprobs": None,
                        }
                    ],
                    "usage": {
                        "prompt_tokens": 230,
                        "completion_tokens": 1,
                        "total_tokens": 231,
                    },
                }
            )
        )
    return serialized_chunks
def mock_chat_streaming_response_chunks_bytes() -> List[bytes]:
    """Return the streaming chunks as newline-terminated UTF-8 bytes."""
    encoded = [
        chunk.encode("utf-8") + b"\n"
        for chunk in mock_chat_streaming_response_chunks()
    ]
    # A trailing empty bytes object simulates the end of the stream.
    encoded.append(b"")
    return encoded
def mock_http_handler_chat_streaming_response() -> MagicMock:
    """Build a MagicMock mimicking a successful sync streaming HTTP response."""
    chunks = mock_chat_streaming_response_chunks()

    def iter_chunk_lines():
        # Mirror httpx's iter_lines(): yield each line of every chunk in order.
        for chunk in chunks:
            yield from chunk.splitlines()

    response = MagicMock()
    response.status_code = 200
    response.iter_lines.side_effect = iter_chunk_lines
    return response
def mock_http_handler_chat_async_streaming_response() -> MagicMock:
    """Build a MagicMock mimicking a successful async streaming HTTP response."""
    chunks = mock_chat_streaming_response_chunks()

    async def aiter_chunk_lines():
        # Mirror httpx's aiter_lines(): yield each line of every chunk in order.
        for chunk in chunks:
            for line in chunk.splitlines():
                yield line

    response = MagicMock()
    response.status_code = 200
    response.aiter_lines.return_value = aiter_chunk_lines()
    return response
def mock_databricks_client_chat_streaming_response() -> MagicMock:
    """Build a MagicMock mimicking a Databricks SDK streaming response.

    The SDK response is subscripted with "content" to obtain a context manager
    whose __enter__ yields a stream supporting read1().
    """
    pending_chunks = mock_chat_streaming_response_chunks_bytes()

    def read_next_chunk(size=-1):
        # Hand out chunks in order; return b"" once exhausted (end of stream).
        return pending_chunks.pop(0) if pending_chunks else b""

    response = MagicMock()
    stream_cm = MagicMock()
    stream_iterator = MagicMock()
    # response["content"] returns the stream's context manager.
    response.__getitem__.return_value = stream_cm
    # Entering the context manager yields the readable stream iterator.
    stream_cm.__enter__.return_value = stream_iterator
    stream_iterator.read1.side_effect = read_next_chunk
    stream_iterator.closed = False
    return response
def mock_embedding_response() -> Dict[str, Any]:
    """Return a canned Databricks /embeddings response payload."""
    embedding_vector = [
        0.06768798828125,
        -0.01291656494140625,
        -0.0501708984375,
        0.0245361328125,
        -0.030364990234375,
    ]
    return {
        "object": "list",
        "model": "bge-large-en-v1.5",
        "data": [
            {
                "index": 0,
                "object": "embedding",
                "embedding": embedding_vector,
            }
        ],
        "usage": {
            "prompt_tokens": 8,
            "total_tokens": 8,
            "completion_tokens": 0,
            "completion_tokens_details": None,
            "prompt_tokens_details": None,
        },
    }
# FIX: `set_base` was a bare parameter with no fixture behind it, so pytest
# would fail collection with "fixture 'set_base' not found"; parametrize it.
@pytest.mark.parametrize("set_base", [True, False])
def test_throws_if_api_base_or_api_key_not_set_without_databricks_sdk(
    monkeypatch, set_base
):
    """Completion and embedding calls must raise BadRequestError when either
    the Databricks base URL or the API key is missing and the SDK is not
    available as a credential fallback."""
    # Simulate that the databricks SDK is not installed.
    monkeypatch.setitem(sys.modules, "databricks.sdk", None)
    # Either message is acceptable depending on which credential is missing.
    err_msg = ["the Databricks base URL and API key are not set", "Missing API Key"]
    if set_base:
        monkeypatch.setenv(
            "DATABRICKS_API_BASE",
            "https://my.workspace.cloud.databricks.com/serving-endpoints",
        )
        monkeypatch.delenv("DATABRICKS_API_KEY", raising=False)
    else:
        monkeypatch.setenv("DATABRICKS_API_KEY", "dapimykey")
        monkeypatch.delenv("DATABRICKS_API_BASE", raising=False)
    with pytest.raises(BadRequestError) as exc:
        litellm.completion(
            model="databricks/dbrx-instruct-071224",
            messages=[{"role": "user", "content": "How are you?"}],
        )
    assert any(msg in str(exc) for msg in err_msg)
    with pytest.raises(BadRequestError) as exc:
        litellm.embedding(
            model="databricks/bge-12312",
            input=["Hello", "World"],
        )
    assert any(msg in str(exc) for msg in err_msg)
def test_completions_with_sync_http_handler(monkeypatch):
    """Non-streaming completion through an explicit sync HTTPHandler: verify
    the response mapping and the exact request (URL, headers, JSON body)."""
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
    sync_handler = HTTPHandler()
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = mock_chat_response()
    # litellm re-exposes the model with the provider prefix.
    expected_response_json = {
        **mock_chat_response(),
        **{
            "model": "databricks/dbrx-instruct-071224",
        },
    }
    messages = [{"role": "user", "content": "How are you?"}]
    with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
        response = litellm.completion(
            model="databricks/dbrx-instruct-071224",
            messages=messages,
            client=sync_handler,
            temperature=0.5,
            extraparam="testpassingextraparam",
        )
    # FIX: expected_response_json was built but never checked — assert it, as
    # the SDK-based variant of this test already does.
    assert response.to_dict() == expected_response_json
    assert mock_post.call_args.kwargs["headers"]["Content-Type"] == "application/json"
    assert mock_post.call_args.kwargs["headers"]["Authorization"] == f"Bearer {api_key}"
    assert mock_post.call_args.kwargs["url"] == f"{base_url}/chat/completions"
    assert mock_post.call_args.kwargs["stream"] == False
    actual_data = json.loads(
        mock_post.call_args.kwargs["data"]
    )  # Deserialize the actual data
    expected_data = {
        "model": "dbrx-instruct-071224",
        "messages": messages,
        "temperature": 0.5,
        "extraparam": "testpassingextraparam",
    }
    assert actual_data == expected_data, f"Unexpected JSON data: {actual_data}"
def test_completions_with_async_http_handler(monkeypatch):
    """Non-streaming completion through an explicit AsyncHTTPHandler: verify
    the response mapping and the exact request (URL, headers, JSON body)."""
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
    async_handler = AsyncHTTPHandler()
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = mock_chat_response()
    # litellm re-exposes the model with the provider prefix.
    expected_response_json = {
        **mock_chat_response(),
        **{
            "model": "databricks/dbrx-instruct-071224",
        },
    }
    messages = [{"role": "user", "content": "How are you?"}]
    with patch.object(
        AsyncHTTPHandler, "post", return_value=mock_response
    ) as mock_post:
        response = asyncio.run(
            litellm.acompletion(
                model="databricks/dbrx-instruct-071224",
                messages=messages,
                client=async_handler,
                temperature=0.5,
                extraparam="testpassingextraparam",
            )
        )
    # FIX: expected_response_json was built but never checked — assert it, as
    # the SDK-based variant of this test already does.
    assert response.to_dict() == expected_response_json
    assert mock_post.call_args.kwargs["headers"]["Content-Type"] == "application/json"
    assert mock_post.call_args.kwargs["headers"]["Authorization"] == f"Bearer {api_key}"
    assert mock_post.call_args.kwargs["url"] == f"{base_url}/chat/completions"
    assert mock_post.call_args.kwargs["stream"] == False
    actual_data = json.loads(
        mock_post.call_args.kwargs["data"]
    )  # Deserialize the actual data
    expected_data = {
        "model": "dbrx-instruct-071224",
        "messages": messages,
        "temperature": 0.5,
        "extraparam": "testpassingextraparam",
    }
    assert actual_data == expected_data, f"Unexpected JSON data: {actual_data}"
def test_completions_streaming_with_sync_http_handler(monkeypatch):
    """Streaming completion through a sync HTTPHandler: consume the stream
    and verify the outbound request (URL, headers, JSON body)."""
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
    handler = HTTPHandler()
    messages = [{"role": "user", "content": "How are you?"}]
    streaming_mock = mock_http_handler_chat_streaming_response()

    with patch.object(HTTPHandler, "post", return_value=streaming_mock) as mock_post:
        response_stream: CustomStreamWrapper = litellm.completion(
            model="databricks/dbrx-instruct-071224",
            messages=messages,
            client=handler,
            temperature=0.5,
            extraparam="testpassingextraparam",
            stream=True,
        )
        # Drain the stream so the mocked chunks are fully consumed.
        received_chunks = [chunk for chunk in response_stream]
        rendered = str(received_chunks)
        assert "dbrx-instruct-071224" in rendered
        assert "chatcmpl" in rendered
        assert len(received_chunks) == 4

    sent = mock_post.call_args.kwargs
    assert sent["headers"]["Content-Type"] == "application/json"
    assert sent["headers"]["Authorization"] == f"Bearer {api_key}"
    assert sent["url"] == f"{base_url}/chat/completions"
    assert sent["stream"] == True
    # Deserialize the actual request body and compare field-for-field.
    assert json.loads(sent["data"]) == {
        "model": "dbrx-instruct-071224",
        "messages": messages,
        "temperature": 0.5,
        "stream": True,
        "extraparam": "testpassingextraparam",
    }, f"Unexpected JSON data: {json.loads(sent['data'])}"
def test_completions_streaming_with_async_http_handler(monkeypatch):
    """Streaming completion through an AsyncHTTPHandler: consume the async
    stream and verify the outbound request (URL, headers, JSON body)."""
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
    handler = AsyncHTTPHandler()
    messages = [{"role": "user", "content": "How are you?"}]
    streaming_mock = mock_http_handler_chat_async_streaming_response()

    with patch.object(
        AsyncHTTPHandler, "post", return_value=streaming_mock
    ) as mock_post:
        response_stream: CustomStreamWrapper = asyncio.run(
            litellm.acompletion(
                model="databricks/dbrx-instruct-071224",
                messages=messages,
                client=handler,
                temperature=0.5,
                extraparam="testpassingextraparam",
                stream=True,
            )
        )

        # Drain the async stream inside its own event loop.
        async def drain_stream():
            return [item async for item in response_stream]

        received_chunks = asyncio.run(drain_stream())
        rendered = str(received_chunks)
        assert "dbrx-instruct-071224" in rendered
        assert "chatcmpl" in rendered
        assert len(received_chunks) == 4

    sent = mock_post.call_args.kwargs
    assert sent["headers"]["Content-Type"] == "application/json"
    assert sent["headers"]["Authorization"] == f"Bearer {api_key}"
    assert sent["url"] == f"{base_url}/chat/completions"
    assert sent["stream"] == True
    # Deserialize the actual request body and compare field-for-field.
    assert json.loads(sent["data"]) == {
        "model": "dbrx-instruct-071224",
        "messages": messages,
        "temperature": 0.5,
        "stream": True,
        "extraparam": "testpassingextraparam",
    }, f"Unexpected JSON data: {json.loads(sent['data'])}"
# FIX: this test imports databricks.sdk unconditionally even though the module
# computes `databricks_sdk_installed` precisely to guard such tests — skip
# instead of erroring when the SDK is absent.
@pytest.mark.skipif(
    not databricks_sdk_installed, reason="Databricks SDK is not installed"
)
def test_completions_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
    """When no env credentials are set, litellm must authenticate through the
    Databricks SDK's WorkspaceClient and derive the serving-endpoints URL from
    the workspace host."""
    # FIX: use raising=False so the test does not error when the env vars are
    # already absent in the environment running the suite.
    monkeypatch.delenv("DATABRICKS_API_BASE", raising=False)
    monkeypatch.delenv("DATABRICKS_API_KEY", raising=False)
    from databricks.sdk import WorkspaceClient
    from databricks.sdk.config import Config

    sync_handler = HTTPHandler()
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = mock_chat_response()
    # litellm re-exposes the model with the provider prefix.
    expected_response_json = {
        **mock_chat_response(),
        **{
            "model": "databricks/dbrx-instruct-071224",
        },
    }
    base_url = "https://my.workspace.cloud.databricks.com"
    api_key = "dapimykey"
    headers = {
        "Authorization": f"Bearer {api_key}",
    }
    messages = [{"role": "user", "content": "How are you?"}]
    mock_workspace_client: WorkspaceClient = MagicMock()
    mock_config: Config = MagicMock()
    # Simulate the behavior of the config property and its methods.
    mock_config.authenticate.side_effect = lambda: headers
    mock_config.host = base_url  # Assign directly as if it's a property
    mock_workspace_client.config = mock_config
    with patch(
        "databricks.sdk.WorkspaceClient", return_value=mock_workspace_client
    ), patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
        response = litellm.completion(
            model="databricks/dbrx-instruct-071224",
            messages=messages,
            client=sync_handler,
            temperature=0.5,
            extraparam="testpassingextraparam",
        )
    assert response.to_dict() == expected_response_json
    assert mock_post.call_args.kwargs["headers"]["Content-Type"] == "application/json"
    assert mock_post.call_args.kwargs["headers"]["Authorization"] == f"Bearer {api_key}"
    # The SDK host gains the /serving-endpoints prefix when building the URL.
    assert (
        mock_post.call_args.kwargs["url"]
        == f"{base_url}/serving-endpoints/chat/completions"
    )
    assert mock_post.call_args.kwargs["stream"] == False
    assert mock_post.call_args.kwargs["data"] == json.dumps(
        {
            "model": "dbrx-instruct-071224",
            "messages": messages,
            "temperature": 0.5,
            "extraparam": "testpassingextraparam",
            "stream": False,
        }
    )
def test_embeddings_with_sync_http_handler(monkeypatch):
    """Embedding through an explicit sync HTTPHandler: verify the response
    passthrough and the single POST to the /embeddings endpoint."""
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
    handler = HTTPHandler()
    mocked_http_response = Mock(spec=httpx.Response)
    mocked_http_response.status_code = 200
    mocked_http_response.json.return_value = mock_embedding_response()
    inputs = ["Hello", "World"]

    with patch.object(
        HTTPHandler, "post", return_value=mocked_http_response
    ) as mock_post:
        response = litellm.embedding(
            model="databricks/bge-large-en-v1.5",
            input=inputs,
            client=handler,
            extraparam="testpassingextraparam",
        )
        assert response.to_dict() == mock_embedding_response()

        expected_body = json.dumps(
            {
                "model": "bge-large-en-v1.5",
                "input": inputs,
                "extraparam": "testpassingextraparam",
            }
        )
        mock_post.assert_called_once_with(
            f"{base_url}/embeddings",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            data=expected_body,
        )
def test_embeddings_with_async_http_handler(monkeypatch):
    """Embedding through an explicit AsyncHTTPHandler: verify the response
    passthrough and the single POST to the /embeddings endpoint."""
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
    handler = AsyncHTTPHandler()
    mocked_http_response = Mock(spec=httpx.Response)
    mocked_http_response.status_code = 200
    mocked_http_response.json.return_value = mock_embedding_response()
    inputs = ["Hello", "World"]

    with patch.object(
        AsyncHTTPHandler, "post", return_value=mocked_http_response
    ) as mock_post:
        response = asyncio.run(
            litellm.aembedding(
                model="databricks/bge-large-en-v1.5",
                input=inputs,
                client=handler,
                extraparam="testpassingextraparam",
            )
        )
        assert response.to_dict() == mock_embedding_response()

        expected_body = json.dumps(
            {
                "model": "bge-large-en-v1.5",
                "input": inputs,
                "extraparam": "testpassingextraparam",
            }
        )
        mock_post.assert_called_once_with(
            f"{base_url}/embeddings",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            data=expected_body,
        )
# FIX: missing the skipif guard its completion counterpart needs — this test
# imports databricks.sdk unconditionally.
@pytest.mark.skipif(
    not databricks_sdk_installed, reason="Databricks SDK is not installed"
)
def test_embeddings_uses_databricks_sdk_if_api_key_and_base_not_specified(monkeypatch):
    """When no env credentials are set, litellm must authenticate through the
    Databricks SDK's WorkspaceClient and derive the serving-endpoints URL from
    the workspace host."""
    from databricks.sdk import WorkspaceClient
    from databricks.sdk.config import Config

    # FIX: this test previously *set* DATABRICKS_API_BASE/KEY, contradicting
    # its name — with env credentials present the mocked SDK path was never
    # exercised. Delete them instead, mirroring the completion counterpart.
    monkeypatch.delenv("DATABRICKS_API_BASE", raising=False)
    monkeypatch.delenv("DATABRICKS_API_KEY", raising=False)
    sync_handler = HTTPHandler()
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = mock_embedding_response()
    base_url = "https://my.workspace.cloud.databricks.com"
    api_key = "dapimykey"
    headers = {
        "Authorization": f"Bearer {api_key}",
    }
    inputs = ["Hello", "World"]
    mock_workspace_client: WorkspaceClient = MagicMock()
    mock_config: Config = MagicMock()
    # Simulate the behavior of the config property and its methods.
    mock_config.authenticate.side_effect = lambda: headers
    mock_config.host = base_url  # Assign directly as if it's a property
    mock_workspace_client.config = mock_config
    with patch(
        "databricks.sdk.WorkspaceClient", return_value=mock_workspace_client
    ), patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
        response = litellm.embedding(
            model="databricks/bge-large-en-v1.5",
            input=inputs,
            client=sync_handler,
            extraparam="testpassingextraparam",
        )
        assert response.to_dict() == mock_embedding_response()
        # The SDK host gains the /serving-endpoints prefix when building the URL.
        mock_post.assert_called_once_with(
            f"{base_url}/serving-endpoints/embeddings",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            data=json.dumps(
                {
                    "model": "bge-large-en-v1.5",
                    "input": inputs,
                    "extraparam": "testpassingextraparam",
                }
            ),
        )
class TestDatabricksCompletion(BaseLLMChatTest, BaseAnthropicChatTest):
    """Runs the shared chat-completion test suites against a Databricks-hosted
    Claude model, skipping cases the provider does not support."""

    def get_base_completion_call_args(self) -> dict:
        # Model used for the shared BaseLLMChatTest cases.
        return {"model": "databricks/databricks-claude-3-7-sonnet"}

    def get_base_completion_call_args_with_thinking(self) -> dict:
        # Same model with Anthropic extended-thinking enabled.
        return {
            "model": "databricks/databricks-claude-3-7-sonnet",
            "thinking": {"type": "enabled", "budget_tokens": 1024},
        }

    def test_pdf_handling(self, pdf_messages):
        pytest.skip("Databricks does not support PDF handling")

    def test_tool_call_no_arguments(self, tool_call_no_arguments):
        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
        pytest.skip("Databricks is openai compatible")
# FIX: `sync_mode` had no fixture/parametrize behind it, and the coroutine
# test had no asyncio marker — pytest would error on the missing fixture and
# never await the test. Add both decorators.
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_databricks_embeddings(sync_mode):
    """Smoke-test sync and async Databricks embeddings and validate the result
    against OpenAI's CreateEmbeddingResponse schema."""
    import openai

    try:
        litellm.set_verbose = True
        # Drop provider-unsupported params (e.g. `instruction`) instead of erroring.
        litellm.drop_params = True
        if sync_mode:
            response = litellm.embedding(
                model="databricks/databricks-bge-large-en",
                input=["good morning from litellm"],
                instruction="Represent this sentence for searching relevant passages:",
            )
        else:
            response = await litellm.aembedding(
                model="databricks/databricks-bge-large-en",
                input=["good morning from litellm"],
                instruction="Represent this sentence for searching relevant passages:",
            )
        print(f"response: {response}")
        # Validate the shape of the response against the OpenAI schema.
        openai.types.CreateEmbeddingResponse.model_validate(
            response.model_dump(), strict=True
        )
        # stubbed endpoint is setup to return this
        # assert response.data[0]["embedding"] == [0.1, 0.2, 0.3]
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")