# NOTE(review): removed page-scrape artifacts ("Spaces:" header and duplicated
# "Runtime error" banners) that were captured along with the source listing.
| # -*- coding: utf-8 -*- | |
| """ | |
| Get model client from model name | |
| """ | |
| import json | |
| import colorama | |
| import gradio as gr | |
| from loguru import logger | |
| from src import config | |
| from src.base_model import ModelType | |
| from src.utils import ( | |
| hide_middle_chars, | |
| i18n, | |
| ) | |
def get_model(
    model_name,
    lora_model_path=None,
    access_key=None,
    temperature=None,
    top_p=None,
    system_prompt=None,
    user_name="",
    original_model=None,
):
    """Build the chat client that matches *model_name*.

    The concrete client class is imported lazily inside a per-type loader so
    that heavyweight backends are only imported when actually selected.  If
    construction fails (or the model type is unknown), the error is logged and
    *original_model* is kept as the active model — best-effort fallback.

    Returns a 6-tuple consumed by the Gradio UI:
    (model, status message, chatbot update, LoRA dropdown update,
     raw access key, masked access key).
    """
    status_msg = i18n("模型设置为了:") + f" {model_name}"
    model_type = ModelType.get_type(model_name)
    lora_choices = ["No LoRA"]
    # Non-OpenAI backends cannot use the remote embedding endpoint.
    if model_type != ModelType.OpenAI:
        config.local_embedding = True
    model = original_model
    chatbot = gr.Chatbot.update(label=model_name)

    def _load_openai():
        logger.info(f"正在加载OpenAI模型: {model_name}")
        from src.openai_client import OpenAIClient

        return OpenAIClient(
            model_name=model_name,
            api_key=access_key,
            system_prompt=system_prompt,
            user_name=user_name,
        )

    def _load_openai_vision():
        logger.info(f"正在加载OpenAI Vision模型: {model_name}")
        from src.openai_client import OpenAIVisionClient

        return OpenAIVisionClient(model_name, api_key=access_key, user_name=user_name)

    def _load_chatglm():
        logger.info(f"正在加载ChatGLM模型: {model_name}")
        from src.chatglm import ChatGLMClient

        return ChatGLMClient(model_name, user_name=user_name)

    def _load_llama():
        logger.info(f"正在加载LLaMA模型: {model_name}")
        from src.llama import LLaMAClient

        return LLaMAClient(model_name, user_name=user_name)

    def _load_zhipu():
        # todo: fix zhipu bug
        logger.info(f"正在加载ZhipuAI模型: {model_name}")
        from src.zhipu_client import ZhipuAIClient

        return ZhipuAIClient(
            model_name=model_name,
            api_key=access_key,
            system_prompt=system_prompt,
            user_name=user_name,
        )

    loaders = {
        ModelType.OpenAI: _load_openai,
        ModelType.OpenAIVision: _load_openai_vision,
        ModelType.ChatGLM: _load_chatglm,
        ModelType.LLaMA: _load_llama,
        ModelType.ZhipuAI: _load_zhipu,
    }

    try:
        if model_type == ModelType.Unknown:
            raise ValueError(f"未知模型: {model_name}")
        loader = loaders.get(model_type)
        if loader is not None:
            model = loader()
    except Exception as e:
        # Best-effort: log and fall back to the previously active model.
        logger.error(e)

    logger.info(status_msg)
    masked_key = hide_middle_chars(access_key)
    # Carry conversation state over from the model being replaced.
    if original_model is not None and model is not None:
        model.history = original_model.history
        model.history_file_path = original_model.history_file_path
    return model, status_msg, chatbot, gr.Dropdown.update(choices=lora_choices, visible=False), access_key, masked_key
| if __name__ == "__main__": | |
| with open("../config.json", "r") as f: | |
| openai_api_key = json.load(f)["openai_api_key"] | |
| print('key:', openai_api_key) | |
| client = get_model(model_name="gpt-3.5-turbo", access_key=openai_api_key)[0] | |
| chatbot = [] | |
| stream = False | |
| # 测试账单功能 | |
| logger.info(colorama.Back.GREEN + "测试账单功能" + colorama.Back.RESET) | |
| logger.info(client.billing_info()) | |
| # 测试问答 | |
| logger.info(colorama.Back.GREEN + "测试问答" + colorama.Back.RESET) | |
| question = "巴黎是中国的首都吗?" | |
| for i in client.predict(inputs=question, chatbot=chatbot, stream=stream): | |
| logger.info(i) | |
| logger.info(f"测试问答后history : {client.history}") | |
| # 测试记忆力 | |
| logger.info(colorama.Back.GREEN + "测试记忆力" + colorama.Back.RESET) | |
| question = "我刚刚问了你什么问题?" | |
| for i in client.predict(inputs=question, chatbot=chatbot, stream=stream): | |
| logger.info(i) | |
| logger.info(f"测试记忆力后history : {client.history}") | |
| # 测试重试功能 | |
| logger.info(colorama.Back.GREEN + "测试重试功能" + colorama.Back.RESET) | |
| for i in client.retry(chatbot=chatbot, stream=stream): | |
| logger.info(i) | |
| logger.info(f"重试后history : {client.history}") | |