# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description:
"""
import platform

from loguru import logger

from src.base_model import BaseLLMModel
from src.presets import LOCAL_MODELS


class ChatGLMClient(BaseLLMModel):
    def __init__(self, model_name, user_name=""):
        super().__init__(model_name=model_name, user=user_name)
        import torch
        from transformers import AutoModel, AutoTokenizer

        system_name = platform.system()
        logger.info(f"Loading model from {model_name}")
        if model_name in LOCAL_MODELS:
            model_path = LOCAL_MODELS[model_name]
        else:
            model_path = model_name
        self.CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        quantified = False
        if "int4" in model_name:
            quantified = True
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map='auto', torch_dtype='auto')
        if torch.cuda.is_available():
            logger.info("CUDA is available, using CUDA")
            model = model.half().cuda()
        # MPS acceleration still has some issues, so it is not used for now
        elif system_name == "Darwin" and model_path is not None and not quantified:
            logger.info("Running on macOS, using MPS")
            # running on macOS and model already downloaded
            model = model.half().to("mps")
        else:
            logger.info("GPU is not available, using CPU")
            model = model.float()
        model = model.eval()
        logger.info(f"Model loaded from {model_path}")
        self.CHATGLM_MODEL = model
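
    # Illustrative sketch, not part of the original file: LOCAL_MODELS (from
    # src.presets) is assumed to map model names to local checkpoint paths, and a
    # lookup miss falls back to treating model_name as a Hugging Face repo id.
    # Hypothetical values:
    #   LOCAL_MODELS = {"chatglm3-6b": "./models/chatglm3-6b"}
    #   ChatGLMClient("chatglm3-6b")             # loads from ./models/chatglm3-6b
    #   ChatGLMClient("THUDM/chatglm2-6b-int4")  # pulls from the Hub; "int4" sets quantified=True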

    def _get_glm3_style_input(self):
        history = self.history
        query = history.pop()["content"]
        return history, query

    def _get_glm2_style_input(self):
        history = [x["content"] for x in self.history]
        query = history.pop()
        logger.debug(f"{history}")
        assert len(history) % 2 == 0, f"History should be even length. current history is: {history}"
        history = [[history[i], history[i + 1]]
                   for i in range(0, len(history), 2)]
        return history, query
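
    # Illustrative sketch, not part of the original file: assuming self.history holds
    # alternating user/assistant dicts such as
    #   [{"content": "hi"}, {"content": "hello!"}, {"content": "how are you?"}]
    # _get_glm2_style_input() returns
    #   history = [["hi", "hello!"]]  and  query = "how are you?"
    # i.e. paired past turns plus the trailing query, the format ChatGLM2's
    # chat()/stream_chat() expect.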

    def _get_glm_style_input(self):
        if "glm2" in self.model_name:
            return self._get_glm2_style_input()
        else:
            return self._get_glm3_style_input()

    def get_answer_at_once(self):
        history, query = self._get_glm_style_input()
        logger.debug(f"{history}")
        response, _ = self.CHATGLM_MODEL.chat(
            self.CHATGLM_TOKENIZER, query, history=history)
        return response, len(response)

    def get_answer_stream_iter(self):
        history, query = self._get_glm_style_input()
        logger.debug(f"{history}")
        for response, history in self.CHATGLM_MODEL.stream_chat(
                self.CHATGLM_TOKENIZER,
                query,
                history,
                max_length=self.token_upper_limit,
                top_p=self.top_p,
                temperature=self.temperature,
        ):
            yield response
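

# Minimal usage sketch, not part of the original module. It assumes BaseLLMModel
# keeps conversation state in `self.history` as a list of {"role", "content"}
# dicts (only "content" is read above), and that a "chatglm3-6b" checkpoint is
# registered in LOCAL_MODELS or downloadable from the Hugging Face Hub; both are
# assumptions, adjust to your setup.
if __name__ == "__main__":
    client = ChatGLMClient("chatglm3-6b", user_name="demo")
    client.history = [{"role": "user", "content": "Hello, who are you?"}]
    answer, _ = client.get_answer_at_once()
    print(answer)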