Spaces:
Runtime error
Runtime error
| # -*- coding:utf-8 -*- | |
| import csv | |
| import datetime | |
| import getpass | |
| import hashlib | |
| import html | |
| import json | |
| import os | |
| import pickle | |
| import re | |
| import threading | |
| from enum import Enum | |
| from typing import List, Union | |
| from typing import TYPE_CHECKING | |
| import colorama | |
| import gradio as gr | |
| import pandas as pd | |
| import requests | |
| import tiktoken | |
| from loguru import logger | |
| from markdown import markdown | |
| from pygments import highlight | |
| from pygments.formatters import HtmlFormatter | |
| from pygments.lexers import get_lexer_by_name | |
| from pypinyin import lazy_pinyin | |
| from src.config import retrieve_proxy, hide_history_when_not_logged_in, config_file | |
| from src.presets import ALREADY_CONVERTED_MARK, HISTORY_DIR, TEMPLATES_DIR, i18n, LOCAL_MODELS, ONLINE_MODELS | |
| from src.shared import state | |
| if TYPE_CHECKING: | |
| from typing import TypedDict | |
| class DataframeData(TypedDict): | |
| headers: List[str] | |
| data: List[List[Union[str, int, bool]]] | |
# ---------------------------------------------------------------------------
# Thin delegation wrappers: Gradio event handlers forward to the currently
# selected model object. Handlers that may fire before a model exists guard
# on ``current_model`` being set.
# ---------------------------------------------------------------------------


def predict(current_model, *args):
    """Stream prediction chunks from the current model (no-op without a model)."""
    if current_model:
        # ``yield from`` replaces the manual loop and stops shadowing the
        # builtin ``iter``.
        yield from current_model.predict(*args)


def billing_info(current_model):
    """Return the model's billing information (None without a model)."""
    if current_model:
        return current_model.billing_info()


def set_key(current_model, *args):
    return current_model.set_key(*args)


def load_chat_history(current_model, *args):
    return current_model.load_chat_history(*args)


def delete_chat_history(current_model, *args):
    return current_model.delete_chat_history(*args)


def interrupt(current_model, *args):
    return current_model.interrupt(*args)


def reset(current_model, *args):
    """Reset the conversation (no-op without a model)."""
    if current_model:
        return current_model.reset(*args)


def retry(current_model, *args):
    """Stream a retried generation from the current model."""
    yield from current_model.retry(*args)


def delete_first_conversation(current_model, *args):
    return current_model.delete_first_conversation(*args)


def delete_last_conversation(current_model, *args):
    return current_model.delete_last_conversation(*args)


def set_system_prompt(current_model, *args):
    return current_model.set_system_prompt(*args)


def rename_chat_history(current_model, *args):
    return current_model.rename_chat_history(*args)


def auto_name_chat_history(current_model, *args):
    """Let the model pick a title for the chat (no-op without a model)."""
    if current_model:
        return current_model.auto_name_chat_history(*args)


def export_markdown(current_model, *args):
    return current_model.export_markdown(*args)


def upload_chat_history(current_model, *args):
    # Uploading a history file is just loading it from a different source.
    return current_model.load_chat_history(*args)


def set_token_upper_limit(current_model, *args):
    return current_model.set_token_upper_limit(*args)


def set_temperature(current_model, *args):
    current_model.set_temperature(*args)


def set_top_p(current_model, *args):
    current_model.set_top_p(*args)


def set_n_choices(current_model, *args):
    current_model.set_n_choices(*args)


def set_stop_sequence(current_model, *args):
    current_model.set_stop_sequence(*args)


def set_max_tokens(current_model, *args):
    current_model.set_max_tokens(*args)


def set_presence_penalty(current_model, *args):
    current_model.set_presence_penalty(*args)


def set_frequency_penalty(current_model, *args):
    current_model.set_frequency_penalty(*args)


def set_logit_bias(current_model, *args):
    current_model.set_logit_bias(*args)


def set_user_identifier(current_model, *args):
    current_model.set_user_identifier(*args)


def set_single_turn(current_model, *args):
    current_model.set_single_turn(*args)


def handle_file_upload(current_model, *args):
    return current_model.handle_file_upload(*args)


def handle_summarize_index(current_model, *args):
    # Note the name mapping: the UI handler delegates to ``summarize_index``.
    return current_model.summarize_index(*args)


def like(current_model, *args):
    return current_model.like(*args)


def dislike(current_model, *args):
    return current_model.dislike(*args)
def count_token(input_str):
    """Count tokens in *input_str* using the cl100k_base encoding.

    Accepts a plain string or an OpenAI-style message dict with ``role`` and
    ``content`` keys; a dict is flattened to text before counting.
    """
    encoding = tiktoken.get_encoding("cl100k_base")
    # ``isinstance`` is the idiomatic type check (and accepts dict subclasses).
    if isinstance(input_str, dict):
        input_str = f"role: {input_str['role']}, content: {input_str['content']}"
    return len(encoding.encode(input_str))
def markdown_to_html_with_syntax_highlight(md_str):  # deprecated
    """Render Markdown to HTML, pygments-highlighting fenced code blocks."""

    def replacer(match):
        lang = match.group(1) or "text"
        code = match.group(2)
        try:
            lexer = get_lexer_by_name(lang, stripall=True)
        except ValueError:
            # Unknown language tag: fall back to plain-text highlighting.
            lexer = get_lexer_by_name("text", stripall=True)
        highlighted_code = highlight(code, lexer, HtmlFormatter())
        return f'<pre><code class="{lang}">{highlighted_code}</code></pre>'

    code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```"
    with_code = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
    return markdown(with_code)


def normalize_markdown(md_text: str) -> str:  # deprecated
    """Insert blank lines before lists and drop blank lines inside them."""
    list_item = re.compile(r"^(\d+\.|-|\*|\+)\s")
    source_lines = md_text.split("\n")
    output = []
    in_list = False
    for idx, line in enumerate(source_lines):
        if list_item.match(line.strip()):
            # A list needs a preceding blank line to render correctly.
            if not in_list and idx > 0 and source_lines[idx - 1].strip() != "":
                output.append("")
            in_list = True
            output.append(line)
        elif in_list and line.strip() == "":
            # Keep the blank line only when it actually terminates the list.
            if idx < len(source_lines) - 1 and not list_item.match(
                source_lines[idx + 1].strip()
            ):
                output.append(line)
        else:
            in_list = False
            output.append(line)
    return "\n".join(output)


def convert_mdtext(md_text):  # deprecated
    """Convert mixed Markdown/code text to HTML plus a hidden raw copy."""
    fenced = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
    code_blocks = fenced.findall(md_text)
    non_code_parts = fenced.split(md_text)[::2]
    rendered = []
    raw = f'<div class="raw-message hideM">{html.escape(md_text)}</div>'
    for non_code, code in zip(non_code_parts, code_blocks + [""]):
        if non_code.strip():
            rendered.append(
                markdown(normalize_markdown(non_code), extensions=["tables"])
            )
        if code.strip():
            # Language detection / blank-line stripping disabled: both broke
            # on very large code blocks.
            rendered.append(
                markdown_to_html_with_syntax_highlight(f"\n```{code}\n\n```")
            )
    output = f'<div class="md-message">{"".join(rendered)}</div>'
    return output + raw + ALREADY_CONVERTED_MARK
def clip_rawtext(chat_message, need_escape=True):
    """Prepare the raw (unrendered) copy of a chat message.

    Cuts off the appended display section after the marker <hr>, then wraps
    plain text in fake <pre> blocks (optionally Markdown-escaped) while
    letting agent-prefix blocks pass through untouched.
    """
    # First, clip everything after the appended-display <hr> marker.
    hr_pattern = r'\n\n<hr class="append-display no-in-raw" />(.*?)'
    hr_match = re.search(hr_pattern, chat_message, re.DOTALL)
    message_clipped = chat_message if hr_match is None else chat_message[: hr_match.start()]
    # Second, keep agent-prefix blocks out of the escaping below.
    agent_prefix_pattern = (
        r'(<!-- S O PREFIX -->.*?<!-- E O PREFIX -->)'
    )
    agent_parts = re.split(agent_prefix_pattern, message_clipped, flags=re.DOTALL)
    pieces = []
    for index, part in enumerate(agent_parts):
        if index % 2 == 0:
            # Even indices are plain text between prefix blocks.
            if part not in ("", "\n"):
                body = escape_markdown(part) if need_escape else part
                pieces.append(f'<pre class="fake-pre">{body}</pre>')
        else:
            # Odd indices are the captured prefix blocks themselves.
            pieces.append(part.replace(' data-fancybox="gallery"', ''))
    return "".join(pieces)


def convert_bot_before_marked(chat_message):
    """
    Wrap a bot message for the frontend before marked.js rendering.

    Note: the output must not be indented, otherwise marked parses it as a
    code block.
    """
    if '<div class="md-message">' in chat_message:
        return chat_message
    raw = f'<div class="raw-message hideM">{clip_rawtext(chat_message)}</div>'
    fenced = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
    code_blocks = fenced.findall(chat_message)
    non_code_parts = fenced.split(chat_message)[::2]
    segments = []
    for non_code, code in zip(non_code_parts, code_blocks + [""]):
        if non_code.strip():
            segments.append(non_code)
        if code.strip():
            segments.append(f"\n```{code}\n```")
    md = f'<div class="md-message">\n\n{"".join(segments)}\n</div>'
    return raw + md


def convert_user_before_marked(chat_message):
    """Wrap a user message in its display div (Markdown-escaped) unless already wrapped."""
    if '<div class="user-message">' in chat_message:
        return chat_message
    return f'<div class="user-message">{escape_markdown(chat_message)}</div>'
def escape_markdown(text):
    """
    Escape Markdown special characters to HTML-safe equivalents.

    Each Markdown control character is replaced with its numeric HTML
    character reference so user text renders literally instead of being
    interpreted as Markdown; newlines become ``<br>`` tags and runs of four
    spaces become non-breaking spaces.
    """
    # The previous table mapped every character to itself (the entities were
    # lost to an encoding pass), which contradicted this function's contract.
    escape_chars = {
        # ' ': '&nbsp;',
        "_": "&#95;",
        "*": "&#42;",
        "[": "&#91;",
        "]": "&#93;",
        "(": "&#40;",
        ")": "&#41;",
        "{": "&#123;",
        "}": "&#125;",
        "#": "&#35;",
        "+": "&#43;",
        "-": "&#45;",
        ".": "&#46;",
        "!": "&#33;",
        "`": "&#96;",
        ">": "&#62;",
        "<": "&#60;",
        "|": "&#124;",
        "$": "&#36;",
        ":": "&#58;",
        "\n": "<br>",
    }
    # Preserve indentation in the rendered HTML.
    text = text.replace("    ", "&nbsp;&nbsp;&nbsp;&nbsp;")
    return "".join(escape_chars.get(c, c) for c in text)
def convert_asis(userinput):  # deprecated
    """Wrap user input verbatim (HTML-escaped) and mark it as converted."""
    escaped = html.escape(userinput)
    return f'<p style="white-space:pre-wrap;">{escaped}</p>' + ALREADY_CONVERTED_MARK


def detect_converted_mark(userinput):  # deprecated
    """True if the text already carries the converted-HTML marker."""
    try:
        return userinput.endswith(ALREADY_CONVERTED_MARK)
    except:
        # Non-string input is treated as already converted.
        return True


def detect_language(code):  # deprecated
    """Split a fenced code body into (language tag, remaining code)."""
    first_line = "" if code.startswith("\n") else code.strip().split("\n", 1)[0]
    if first_line:
        return first_line.lower(), code[len(first_line):].lstrip()
    return "", code


def construct_text(role, text):
    """Build an OpenAI-style chat message dict."""
    return {"role": role, "content": text}


def construct_user(text):
    """Message dict spoken by the user."""
    return construct_text("user", text)


def construct_system(text):
    """Message dict carrying the system prompt."""
    return construct_text("system", text)


def construct_assistant(text):
    """Message dict spoken by the assistant."""
    return construct_text("assistant", text)
def save_file(filename, model, chatbot):
    """Persist the current conversation to disk as JSON plus a Markdown copy.

    The JSON file stores the full model state (system prompt, history,
    chatbot display list and generation parameters); a sibling ``.md`` file
    holds a human-readable transcript. Returns the path of the JSON file.
    ``filename`` may be None (auto-named), a bare name, or a full path.
    """
    system = model.system_prompt
    history = model.history
    user_name = model.user_name
    # Each user gets a private history directory.
    os.makedirs(os.path.join(HISTORY_DIR, user_name), exist_ok=True)
    if filename is None:
        filename = new_auto_history_filename(user_name)
    # Normalise the extension: strip ".md", then force ".json".
    if filename.endswith(".md"):
        filename = filename[:-3]
    if not filename.endswith(".json") and not filename.endswith(".md"):
        filename += ".json"
    if filename == ".json":
        raise Exception("文件名不能为空")  # empty filename is not allowed
    json_s = {
        "system": system,
        "history": history,
        "chatbot": chatbot,
        "model_name": model.model_name,
        "single_turn": model.single_turn,
        "temperature": model.temperature,
        "top_p": model.top_p,
        "n_choices": model.n_choices,
        "stop_sequence": model.stop_sequence,
        "token_upper_limit": model.token_upper_limit,
        "max_generation_token": model.max_generation_token,
        "presence_penalty": model.presence_penalty,
        "frequency_penalty": model.frequency_penalty,
        "logit_bias": model.logit_bias,
        "user_identifier": model.user_identifier,
        "metadata": model.metadata,
    }
    # A filename containing path separators is used verbatim; a bare name
    # goes into the user's history directory.
    if not filename == os.path.basename(filename):
        history_file_path = filename
    else:
        history_file_path = os.path.join(HISTORY_DIR, user_name, filename)
    with open(history_file_path, "w", encoding="utf-8") as f:
        json.dump(json_s, f, ensure_ascii=False, indent=4)
    # The Markdown transcript always lands in the user's history directory.
    filename = os.path.basename(filename)
    filename_md = filename[:-5] + ".md"
    md_s = f"system: \n- {system} \n"
    for data in history:
        md_s += f"\n{data['role']}: \n- {data['content']} \n"
    with open(
        os.path.join(HISTORY_DIR, user_name, filename_md), "w", encoding="utf8"
    ) as f:
        f.write(md_s)
    return os.path.join(HISTORY_DIR, user_name, filename)
def sorted_by_pinyin(list):
    """Sort strings by the first letter of their pinyin romanisation."""

    def first_initial(text):
        return lazy_pinyin(text)[0][0]

    return sorted(list, key=first_initial)


def sorted_by_last_modified_time(list, dir):
    """Sort file names in *dir* by creation time, newest first."""

    def ctime(name):
        return os.path.getctime(os.path.join(dir, name))

    return sorted(list, key=ctime, reverse=True)


def get_file_names_by_type(dir, filetypes=[".json"]):
    """Return file names in *dir* whose suffix matches any entry of *filetypes*."""
    os.makedirs(dir, exist_ok=True)
    files = []
    for suffix in filetypes:
        files += [name for name in os.listdir(dir) if name.endswith(suffix)]
    return files


def get_file_names_by_pinyin(dir, filetypes=[".json"]):
    """File names in *dir*, pinyin-sorted."""
    files = get_file_names_by_type(dir, filetypes)
    if files != [""]:
        files = sorted_by_pinyin(files)
    logger.debug(f"files are:{files}")
    return files


def get_file_names_dropdown_by_pinyin(dir, filetypes=[".json"]):
    """Dropdown update listing pinyin-sorted file names."""
    return gr.Dropdown.update(choices=get_file_names_by_pinyin(dir, filetypes))


def get_file_names_by_last_modified_time(dir, filetypes=[".json"]):
    """File names in *dir*, newest first."""
    files = get_file_names_by_type(dir, filetypes)
    if files != [""]:
        files = sorted_by_last_modified_time(files, dir)
    logger.debug(f"files are:{files}")
    return files


def get_history_names(user_name=""):
    """Saved-history names (without extension) for a user, newest first."""
    logger.debug(f"从用户 {user_name} 中获取历史记录文件名列表")
    if user_name == "" and hide_history_when_not_logged_in:
        # Anonymous visitors may be barred from seeing any history.
        return []
    history_files = get_file_names_by_last_modified_time(
        os.path.join(HISTORY_DIR, user_name)
    )
    return [f[: f.rfind(".")] for f in history_files]


def get_first_history_name(user_name=""):
    """Most recent history name, or "" when there is none."""
    history_names = get_history_names(user_name)
    return history_names[0] if history_names else ""


def get_history_list(user_name=""):
    """Radio update listing the user's histories."""
    return gr.Radio.update(choices=get_history_names(user_name))


def init_history_list(user_name="", prepend=""):
    """Build the initial history radio list, optionally pinning *prepend* on top."""
    history_names = get_history_names(user_name)
    if prepend and prepend not in history_names:
        history_names.insert(0, prepend)
    return gr.Radio.update(
        choices=history_names, value=history_names[0] if history_names else ""
    )


def filter_history(user_name, keyword):
    """Filter history names by regex keyword; a bad pattern shows everything."""
    history_names = get_history_names(user_name)
    try:
        matches = [name for name in history_names if re.search(keyword, name)]
        return gr.update(choices=matches)
    except:
        return gr.update(choices=history_names)
def load_template(filename, mode=0):
    """Load a prompt-template file (.json list or headed CSV).

    mode 0 -> (act->prompt dict, dropdown update); mode 1 -> pinyin-sorted
    act names; mode 2 -> act->prompt dict only.
    """
    logger.debug(f"加载模板文件(unknown),模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
    path = os.path.join(TEMPLATES_DIR, filename)
    if filename.endswith(".json"):
        with open(path, "r", encoding="utf8") as f:
            lines = [[i["act"], i["prompt"]] for i in json.load(f)]
    else:
        with open(path, "r", encoding="utf8") as csvfile:
            # Drop the CSV header row.
            lines = list(csv.reader(csvfile))[1:]
    if mode == 1:
        return sorted_by_pinyin([row[0] for row in lines])
    if mode == 2:
        return {row[0]: row[1] for row in lines}
    choices = sorted_by_pinyin([row[0] for row in lines])
    return {row[0]: row[1] for row in lines}, gr.Dropdown.update(choices=choices)


def get_template_names():
    """Template file names, pinyin-sorted."""
    logger.debug("获取模板文件名列表")
    return get_file_names_by_pinyin(TEMPLATES_DIR, filetypes=[".csv", "json"])


def get_template_dropdown():
    """Dropdown update listing all template files."""
    logger.debug("获取模板下拉菜单")
    return gr.Dropdown.update(choices=get_template_names())


def get_template_content(templates, selection, original_system_prompt):
    """Look up the selected template; fall back to the original system prompt."""
    logger.debug(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
    try:
        return templates[selection]
    except:
        return original_system_prompt
def reset_textbox():
    """Clear the input textbox."""
    logger.debug("重置文本框")
    return gr.update(value="")


def reset_default():
    """Reset API host and proxy to their defaults; report the change."""
    default_host = state.reset_api_host()
    retrieve_proxy("")
    return gr.update(value=default_host), gr.update(value=""), "API-Host 和代理已重置"


def change_api_host(host):
    """Point the client at a new API host and report it."""
    state.set_api_host(host)
    msg = f"API-Host更改为了{host}"
    logger.info(msg)
    return msg


def change_proxy(proxy):
    """Switch the outbound proxy (also exported via HTTPS_PROXY) and report it."""
    retrieve_proxy(proxy)
    os.environ["HTTPS_PROXY"] = proxy
    msg = f"代理更改为了{proxy}"
    logger.info(msg)
    return msg


def hide_middle_chars(s):
    """Mask all but the first and last four characters of a secret string."""
    if s is None:
        return ""
    if len(s) <= 8:
        # Too short to mask meaningfully.
        return s
    return s[:4] + "*" * (len(s) - 8) + s[-4:]


def submit_key(key):
    """Trim a submitted API key; return it with a masked status message."""
    key = key.strip()
    msg = f"API密钥更改为了{hide_middle_chars(key)}"
    logger.info(msg)
    return key, msg


def replace_today(prompt):
    """Substitute the {current_date} placeholder with today's ISO date."""
    today = datetime.datetime.today().strftime("%Y-%m-%d")
    return prompt.replace("{current_date}", today)
# Cached geo-IP message, filled in asynchronously by the fetch thread.
SERVER_GEO_IP_MSG = None
# Guards against launching more than one fetch thread at a time.
FETCHING_IP = False


def get_geoip():
    """Return a human-readable message about the server's IP region.

    The lookup runs on a background thread so the UI is not blocked; until
    it finishes, a "please wait" message is returned. The result is cached
    in SERVER_GEO_IP_MSG for all later calls.
    """
    global SERVER_GEO_IP_MSG, FETCHING_IP
    # Already resolved: return the cached message directly.
    if SERVER_GEO_IP_MSG is not None:
        return SERVER_GEO_IP_MSG
    # A fetch is already in flight: tell the caller to wait.
    if FETCHING_IP:
        return i18n("IP地址信息正在获取中,请稍候...")

    # Worker executed on a separate thread to do the actual lookup.
    def fetch_ip():
        global SERVER_GEO_IP_MSG, FETCHING_IP
        try:
            with retrieve_proxy():
                response = requests.get("https://ipapi.co/json/", timeout=5)
                data = response.json()
        except:
            # Any network/JSON failure degrades to an "error" payload.
            data = {"error": True, "reason": "连接ipapi失败"}
        if "error" in data.keys():
            logger.warning(f"无法获取IP地址信息。\n{data}")
            SERVER_GEO_IP_MSG = i18n("你可以使用聊天功能。")
        else:
            country = data["country_name"]
            if country == "China":
                SERVER_GEO_IP_MSG = "**您的IP区域:中国。**"
            else:
                SERVER_GEO_IP_MSG = i18n("您的IP区域:") + f"{country}。"
        logger.info(SERVER_GEO_IP_MSG)
        FETCHING_IP = False

    # Mark the fetch as in progress and launch the worker thread.
    FETCHING_IP = True
    thread = threading.Thread(target=fetch_ip)
    thread.start()
    # Placeholder message; the thread will update the cache.
    return i18n("正在获取IP地址信息,请稍候...")
def find_n(lst, max_num):
    """Return how many trailing entries of *lst* fit under *max_num*.

    Drops entries from the front until the remaining sum is below the limit;
    if the whole list already fits, returns its full length.
    """
    remaining = sum(lst)
    if remaining < max_num:
        return len(lst)
    for dropped, value in enumerate(lst):
        if remaining - value < max_num:
            return len(lst) - dropped - 1
        remaining -= value
    return 1
def start_outputing():
    """Swap buttons for generation: hide send, show cancel."""
    logger.debug("显示取消按钮,隐藏发送按钮")
    return gr.Button.update(visible=False), gr.Button.update(visible=True)


def end_outputing():
    """Swap buttons back after generation: show send, hide cancel."""
    return (
        gr.Button.update(visible=True),
        gr.Button.update(visible=False),
    )


def cancel_outputing():
    """Ask the shared state to interrupt the running generation."""
    logger.info("中止输出……")
    state.interrupt()


def transfer_input(inputs):
    """Forward the textbox content and reset the UI in a single return.

    Returning everything at once keeps the event round-trip (and thus the
    perceived latency) low.
    """
    # Reuse the helpers' return values instead of discarding them into
    # unused locals and rebuilding identical updates by hand.
    textbox = reset_textbox()
    send_btn_update, cancel_btn_update = start_outputing()
    return (
        inputs,
        textbox,
        send_btn_update,
        cancel_btn_update,
    )


def update_chuanhu():
    """Report completion of the self-update action."""
    return gr.Markdown.update(value=i18n("done"))
def add_source_numbers(lst, source_name="Source", use_source=True):
    """Prefix each entry with a citation number, optionally with its source."""
    if use_source:
        return [
            f'[{number}]\t "{item[0]}"\n{source_name}: {item[1]}'
            for number, item in enumerate(lst, start=1)
        ]
    return [f'[{number}]\t "{item}"' for number, item in enumerate(lst, start=1)]


def add_details(lst):
    """Wrap each text in a collapsible <details> block, summarised by its first 25 chars."""
    wrapped = []
    for txt in lst:
        summary = txt[:25].replace("\n", "")
        wrapped.append(
            f"<details><summary>{summary}...</summary><p>{txt}</p></details>"
        )
    return wrapped
def sheet_to_string(sheet, sheet_name=None):
    """Flatten a DataFrame worksheet into one "col: value, ..." string per row."""
    rows_as_text = []
    for _, row in sheet.iterrows():
        text = ""
        for column in sheet.columns:
            text += f"{column}: {row[column]}, "
        rows_as_text.append(text.rstrip(", ") + ".")
    return rows_as_text


def excel_to_string(file_path):
    """Read every worksheet of an Excel file into a flat list of row strings."""
    workbook = pd.read_excel(file_path, engine="openpyxl", sheet_name=None)
    result = []
    for sheet_name, sheet_data in workbook.items():
        # Flatten the current worksheet and extend the running result.
        result += sheet_to_string(sheet_data, sheet_name=sheet_name)
    return result
def get_last_day_of_month(any_day):
    """Return the date of the last day of *any_day*'s month."""
    # Day 28 exists in every month; four days later always lands in the
    # next month.
    next_month = any_day.replace(day=28) + datetime.timedelta(days=4)
    # Subtracting that date's day-of-month steps back to the last day.
    return next_month - datetime.timedelta(days=next_month.day)


def get_model_source(model_name, alternative_source):
    """Return the known download source for a model name, if any."""
    if model_name == "gpt2-medium":
        return "https://huggingface.co/gpt2-medium"


def refresh_ui_elements_on_load(current_model, selected_model_name, user_name):
    """Bind the user to the model, then refresh like-button visibility and autoload state."""
    current_model.set_user_identifier(user_name)
    return toggle_like_btn_visibility(selected_model_name), *current_model.auto_load()


def toggle_like_btn_visibility(selected_model_name):
    """Only the xmchat model supports like/dislike feedback."""
    return gr.update(visible=(selected_model_name == "xmchat"))


def get_corresponding_file_type_by_model_name(selected_model_name):
    """Image upload for vision models; document types for everything else."""
    if selected_model_name in ["xmchat", "GPT4 Vision"]:
        return ["image"]
    return [".pdf", ".docx", ".pptx", ".epub", ".xlsx", ".txt", "text"]
def new_auto_history_filename(username):
    """Pick a history filename: reuse the newest file if it is empty, else a dated name."""
    latest_file = get_first_history_name(username)
    if latest_file:
        path = os.path.join(HISTORY_DIR, username, latest_file + ".json")
        with open(path, "r", encoding="utf-8") as f:
            if len(f.read()) == 0:
                # The most recent history is still empty: reuse it.
                return latest_file
    now = i18n("新对话 ") + datetime.datetime.now().strftime("%m-%d %H-%M")
    return f"{now}.json"


def get_history_filepath(username):
    """Full path of the user's most recent (or a fresh) history file."""
    dirname = os.path.join(HISTORY_DIR, username)
    os.makedirs(dirname, exist_ok=True)
    latest_file = get_first_history_name(username)
    if not latest_file:
        latest_file = new_auto_history_filename(username)
    return os.path.join(dirname, latest_file)


def beautify_err_msg(err_msg):
    """Translate well-known API error messages into friendly localized text."""
    if "insufficient_quota" in err_msg:
        return i18n("剩余配额不足")
    if "The model `gpt-4` does not exist" in err_msg:
        return i18n("你没有权限访问 GPT4")
    if "Resource not found" in err_msg:
        return i18n("请查看 config_example.json,配置 Azure OpenAI")
    return err_msg
def auth_from_conf(username, password):
    """Check a username/password pair against the ``users`` list in the config file.

    Any failure (missing file, malformed JSON, unknown user) counts as a
    rejected login.
    """
    try:
        with open(config_file, encoding="utf-8") as f:
            conf = json.load(f)
        usernames = [entry[0] for entry in conf["users"]]
        passwords = [entry[1] for entry in conf["users"]]
        if username in usernames:
            if passwords[usernames.index(username)] == password:
                return True
        return False
    except:
        return False
def get_files_hash(file_src=None, file_paths=None):
    """MD5 over the concatenated contents of the given files.

    Paths are sorted by basename first so the hash is order-independent;
    ``file_src`` (objects with ``.name``) takes precedence over raw paths.
    """
    if file_src:
        file_paths = [x.name for x in file_src]
    file_paths.sort(key=lambda x: os.path.basename(x))
    digest = hashlib.md5()
    for path in file_paths:
        with open(path, "rb") as fh:
            # Feed the file in 8 KiB chunks to bound memory usage.
            for chunk in iter(lambda: fh.read(8192), b""):
                digest.update(chunk)
    return digest.hexdigest()
def myprint(**args):
    """Debug helper: print the received keyword arguments as a dict."""
    print(args)


def replace_special_symbols(string, replace_string=" "):
    """Replace each special symbol in *string* with *replace_string*."""
    # Character class of all symbols considered "special".
    pattern = r"[!@#$%^&*()<>?/\|}{~:]"
    return re.sub(pattern, replace_string, string)
class ConfigType(Enum):
    """Kinds of values the setup wizard can prompt for."""
    Bool = 1
    String = 2
    Password = 3
    Number = 4
    ListOfStrings = 5


class ConfigItem:
    """One configurable setting: its config key, display name, default and type."""

    def __init__(self, key, name, default=None, type=ConfigType.String) -> None:
        # key: the JSON key written to config.json
        self.key = key
        # name: human-readable label shown in prompts
        self.name = name
        # default: value used when the user declines to set the item
        self.default = default
        # type: a ConfigType controlling how the wizard asks for the value
        self.type = type
def generate_prompt_string(config_item):
    """Build the coloured interactive prompt line for one config item."""
    green, reset = colorama.Fore.GREEN, colorama.Style.RESET_ALL
    prompt = i18n("请输入 ") + green + i18n(config_item.name) + reset
    if config_item.default is not None:
        # Mention the default so the user knows what Enter will pick.
        prompt += i18n(",默认为 ") + green + str(config_item.default) + reset
    return prompt + i18n(":")


def generate_result_string(config_item, config_value):
    """Build the coloured confirmation line echoing the chosen value."""
    cyan, reset = colorama.Fore.CYAN, colorama.Style.RESET_ALL
    return (
        i18n("你设置了 ")
        + cyan
        + i18n(config_item.name)
        + reset
        + i18n(" 为: ")
        + config_value
    )
class SetupWizard:
    """Interactive first-run configuration wizard.

    Prompts on stdin for each configuration section and writes the collected
    settings to ``config.json`` via :meth:`save`.
    """

    def __init__(self, file_path=config_file) -> None:
        # Collected settings; written out by save().
        self.config = {}
        self.file_path = file_path
        # Ask for the interface language first so later prompts are localised.
        language = input(
            '请问是否需要更改语言?可选:"auto", "zh_CN", "en_US", "ja_JP", "ko_KR", "sv_SE", "ru_RU", "vi_VN"\nChange the language? Options: "auto", "zh_CN", "en_US", "ja_JP", "ko_KR", "sv_SE", "ru_RU", "vi_VN"\n目前正在使用中文(zh_CN)\nCurrently using Chinese(zh_CN)\n如果需要,请输入你想用的语言的代码:\nIf you need, please enter the code of the language you want to use:')
        if language.lower() in ["auto", "zh_cn", "en_us", "ja_jp", "ko_kr", "sv_se", "ru_ru", "vi_vn"]:
            i18n.change_language(language)
        else:
            print(
                "你没有输入有效的语言代码,将使用默认语言中文(zh_CN)\nYou did not enter a valid language code, the default language Chinese(zh_CN) will be used.")
        print(i18n("正在进行首次设置,请按照提示进行配置,配置将会被保存在")
              + " config.json "
              + i18n("中。")
              )
        print(
            i18n("在")
            + " example_config.json "
            + i18n("中,包含了可用设置项及其简要说明。")
        )
        print(i18n("现在开始进行交互式配置。碰到不知道该怎么办的设置项时,请直接按回车键跳过,程序会自动选择合适的默认值。")
              )

    def set(self, config_items: List[ConfigItem], prompt: str):
        """Ask whether to configure *config_items*, then prompt for each value.

        Returns:
            bool: True if the items were set, False if the user declined.
        """
        print(colorama.Fore.YELLOW + i18n(prompt) + colorama.Style.RESET_ALL)
        choice = input(i18n("输入 Yes(y) 或 No(n),默认No:"))
        if choice.lower() in ["y", "yes"]:
            for config_item in config_items:
                if config_item.type == ConfigType.Password:
                    # Hide password input; echo only a masked version back.
                    config_value = getpass.getpass(generate_prompt_string(config_item))
                    print(
                        colorama.Fore.CYAN
                        + i18n(config_item.name)
                        + colorama.Style.RESET_ALL
                        + ": "
                        + hide_middle_chars(config_value)
                    )
                    self.config[config_item.key] = config_value
                elif config_item.type == ConfigType.String:
                    config_value = input(generate_prompt_string(config_item))
                    print(generate_result_string(config_item, config_value))
                    self.config[config_item.key] = config_value
                elif config_item.type == ConfigType.Number:
                    config_value = input(generate_prompt_string(config_item))
                    print(generate_result_string(config_item, config_value))
                    try:
                        self.config[config_item.key] = int(config_value)
                    except:
                        # Non-numeric input: leave the key unset (default applies).
                        print("输入的不是数字,将使用默认值。")
                elif config_item.type == ConfigType.ListOfStrings:
                    # Read one string at a time; an empty line terminates the list.
                    config_value = []
                    while True:
                        config_value_item = input(
                            generate_prompt_string(config_item) + i18n(",输入空行结束:")
                        )
                        if config_value_item == "":
                            break
                        config_value.append(config_value_item)
                    print(generate_result_string(config_item, ", ".join(config_value)))
                    self.config[config_item.key] = config_value
                elif config_item.type == ConfigType.Bool:
                    # Saying "yes" to a Bool item simply switches it on.
                    self.config[config_item.key] = True
            return True
        elif choice.lower() in ["n", "no"]:
            for config_item in config_items:
                print(
                    i18n("你选择了不设置 ")
                    + colorama.Fore.RED
                    + i18n(config_item.name)
                    + colorama.Style.RESET_ALL
                    + i18n("。")
                )
                if config_item.default is not None:
                    # Fall back to the declared default when declined.
                    self.config[config_item.key] = config_item.default
            # NOTE(review): this compares the *builtin* ``type`` to
            # ConfigType.Bool, which is always False — ``config_item.type``
            # was probably intended. TODO confirm before changing behavior.
            if type == ConfigType.Bool:
                return True
            return False

    def set_users(self):
        """Interactively collect username/password pairs; True if any were set."""
        # Ask whether to set up user accounts (login required when enabled).
        choice = input(colorama.Fore.YELLOW + i18n(
            "是否设置用户账户?设置后,用户需要登陆才可访问。输入 Yes(y) 或 No(n),默认No:") + colorama.Style.RESET_ALL)
        if choice.lower() in ["y", "yes"]:
            users = []
            while True:
                # An empty username terminates the add-user loop.
                username = input(i18n("请先输入用户名,输入空行结束添加用户:"))
                if username == "":
                    break
                password = getpass.getpass(i18n("请输入密码:"))
                users.append([username, password])
            self.config["users"] = users
            return True
        else:
            print(i18n("你选择了不设置用户账户。"))
            return False

    def __setitem__(self, setting_key: str, value):
        # Allow wizard["key"] = value shorthand for direct settings.
        self.config[setting_key] = value

    def __getitem__(self, setting_key: str):
        return self.config[setting_key]

    def save(self):
        """Write the collected configuration to ``self.file_path`` as pretty JSON."""
        with open(self.file_path, "w", encoding="utf-8") as f:
            json.dump(self.config, f, ensure_ascii=False, indent=4)
| def setup_wizard(): | |
| if not os.path.exists(config_file): | |
| wizard = SetupWizard() | |
| flag = False | |
| # 设置openai_api_key。 | |
| flag = wizard.set( | |
| [ConfigItem("openai_api_key", "OpenAI API Key", type=ConfigType.Password)], | |
| "是否设置默认 OpenAI API Key?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,可以在软件启动后手动输入 API Key。", | |
| ) | |
| if not flag: | |
| flag = wizard.set( | |
| [ | |
| ConfigItem( | |
| "openai_api_key", "OpenAI API Key", type=ConfigType.Password | |
| ) | |
| ], | |
| "如果不设置,将无法使用GPT模型和知识库在线索引功能。如果不设置此选项,您必须每次手动输入API Key。如果不设置,将自动启用本地编制索引的功能,可与本地模型配合使用。请问要设置默认 OpenAI API Key 吗?", | |
| ) | |
| if not flag: | |
| wizard["local_embedding"] = True | |
| # 设置openai_api_base | |
| wizard.set( | |
| [ConfigItem("openai_api_base", "OpenAI API Base", type=ConfigType.String)], | |
| "是否设置默认 OpenAI API Base?如果你在使用第三方API或者CloudFlare Workers等来中转OpenAI API,可以在这里设置。", | |
| ) | |
| # 设置http_proxy | |
| flag = wizard.set( | |
| [ConfigItem("http_proxy", "HTTP 代理", type=ConfigType.String)], | |
| "是否设置默认 HTTP 代理?这可以透过代理使用OpenAI API。", | |
| ) | |
| if flag: | |
| wizard["https_proxy"] = wizard["http_proxy"] | |
| # 设置多 API Key 切换 | |
| flag = wizard.set( | |
| [ConfigItem("api_key_list", "API Key 列表", type=ConfigType.ListOfStrings)], | |
| "是否设置多 API Key 切换?如果设置,将在多个API Key之间切换使用。", | |
| ) | |
| if flag: | |
| wizard["multi_api_key"] = True | |
| # 设置local_embedding | |
| wizard.set( | |
| [ConfigItem("local_embedding", "本地编制索引", type=ConfigType.Bool)], | |
| "是否在本地编制知识库索引?如果是,可以在使用本地模型时离线使用知识库,否则使用OpenAI服务来编制索引(需要OpenAI API Key)。请确保你的电脑有至少16GB内存。本地索引模型需要从互联网下载。", | |
| ) | |
| print( | |
| colorama.Back.GREEN + i18n("现在开始设置其他在线模型的API Key") + colorama.Style.RESET_ALL | |
| ) | |
| # Google Palm | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "google_palm_api_key", | |
| "Google Palm API Key", | |
| type=ConfigType.Password, | |
| ) | |
| ], | |
| "是否设置默认 Google Palm API 密钥?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,可以在软件启动后手动输入 API Key。", | |
| ) | |
| # XMChat | |
| wizard.set( | |
| [ConfigItem("xmchat_api_key", "XMChat API Key", type=ConfigType.Password)], | |
| "是否设置默认 XMChat API 密钥?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,可以在软件启动后手动输入 API Key。", | |
| ) | |
| # MiniMax | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "minimax_api_key", "MiniMax API Key", type=ConfigType.Password | |
| ), | |
| ConfigItem( | |
| "minimax_group_id", "MiniMax Group ID", type=ConfigType.Password | |
| ), | |
| ], | |
| "是否设置默认 MiniMax API 密钥和 Group ID?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,将无法使用 MiniMax 模型。", | |
| ) | |
| # Midjourney | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "midjourney_proxy_api_base", | |
| i18n("你的") + "https://github.com/novicezk/midjourney-proxy" + i18n("代理地址"), | |
| type=ConfigType.String, | |
| ), | |
| ConfigItem( | |
| "midjourney_proxy_api_secret", | |
| "MidJourney Proxy API Secret(用于鉴权访问 api,可选)", | |
| type=ConfigType.Password, | |
| ), | |
| ConfigItem( | |
| "midjourney_discord_proxy_url", | |
| "MidJourney Discord Proxy URL(用于对生成对图进行反代,可选)", | |
| type=ConfigType.String, | |
| ), | |
| ConfigItem( | |
| "midjourney_temp_folder", | |
| "你的 MidJourney 临时文件夹,用于存放生成的图片,填空则关闭自动下载切图(直接显示MJ的四宫格图)", | |
| type=ConfigType.String, | |
| default="files", | |
| ), | |
| ], | |
| "是否设置 Midjourney ?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,将无法使用 Midjourney 模型。", | |
| ) | |
| # Spark | |
| wizard.set( | |
| [ | |
| ConfigItem("spark_appid", "讯飞星火 App ID", type=ConfigType.Password), | |
| ConfigItem( | |
| "spark_api_secret", "讯飞星火 API Secret", type=ConfigType.Password | |
| ), | |
| ConfigItem("spark_api_key", "讯飞星火 API Key", type=ConfigType.Password), | |
| ], | |
| "是否设置讯飞星火?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,将无法使用 讯飞星火 模型。请注意不要搞混App ID和API Secret。", | |
| ) | |
| # Cloude | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "cloude_api_secret", "Cloude API Secret", type=ConfigType.Password | |
| ), | |
| ], | |
| "是否设置Cloude API?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,将无法使用 Cloude 模型。", | |
| ) | |
| # 文心一言 | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "ernie_api_key", "百度云中的文心一言 API Key", type=ConfigType.Password | |
| ), | |
| ConfigItem( | |
| "ernie_secret_key", "百度云中的文心一言 Secret Key", type=ConfigType.Password | |
| ), | |
| ], | |
| "是否设置文心一言?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,将无法使用 文心一言 模型。", | |
| ) | |
| # Azure OpenAI | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "azure_openai_api_key", | |
| "Azure OpenAI API Key", | |
| type=ConfigType.Password, | |
| ), | |
| ConfigItem( | |
| "azure_openai_api_base_url", | |
| "Azure OpenAI API Base URL", | |
| type=ConfigType.String, | |
| ), | |
| ConfigItem( | |
| "azure_openai_api_version", | |
| "Azure OpenAI API Version", | |
| type=ConfigType.String, | |
| ), | |
| ConfigItem( | |
| "azure_deployment_name", | |
| "Azure OpenAI Chat 模型 Deployment 名称", | |
| type=ConfigType.String, | |
| ), | |
| ConfigItem( | |
| "azure_embedding_deployment_name", | |
| "Azure OpenAI Embedding 模型 Deployment 名称", | |
| type=ConfigType.String, | |
| ), | |
| ConfigItem( | |
| "azure_embedding_model_name", | |
| "Azure OpenAI Embedding 模型名称", | |
| type=ConfigType.String, | |
| ), | |
| ], | |
| "是否设置 Azure OpenAI?如果设置,软件启动时会自动加载该API Key,无需在 UI 中手动输入。如果不设置,将无法使用 Azure OpenAI 模型。", | |
| ) | |
| print( | |
| colorama.Back.GREEN + i18n("现在开始进行软件功能设置") + colorama.Style.RESET_ALL | |
| ) | |
| # 用户列表 | |
| wizard.set_users() | |
| # 未登录情况下是否不展示对话历史 | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "hide_history_when_not_logged_in", | |
| "未登录情况下是否不展示对话历史", | |
| type=ConfigType.Bool, | |
| ) | |
| ], | |
| "是否设置未登录情况下是否不展示对话历史?如果设置,未登录情况下将不展示对话历史。", | |
| ) | |
| # 默认模型 | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "default_model", | |
| "默认模型", | |
| type=ConfigType.String, | |
| default="gpt-3.5-turbo", | |
| ) | |
| ], | |
| "是否更改默认模型?如果设置,软件启动时会自动加载该模型,无需在 UI 中手动选择。目前的默认模型为 GPT3.5 Turbo。可选的在线模型有:" | |
| + "\n" | |
| + "\n".join(ONLINE_MODELS) | |
| + "\n" | |
| + "可选的本地模型为:" | |
| + "\n" | |
| + "\n".join(LOCAL_MODELS), | |
| ) | |
| # 是否启用自动加载 | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "hide_history_when_not_logged_in", | |
| "是否不展示对话历史", | |
| type=ConfigType.Bool, | |
| default=False, | |
| ) | |
| ], | |
| "未设置用户名/密码情况下是否不展示对话历史?", | |
| ) | |
| # 如何自动命名对话历史 | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "chat_name_method_index", | |
| "自动命名对话历史的方式(0: 使用日期时间命名;1: 使用第一条提问命名,2: 使用模型自动总结。)", | |
| type=ConfigType.Number, | |
| default=2, | |
| ) | |
| ], | |
| "是否选择自动命名对话历史的方式?", | |
| ) | |
| # 头像 | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "bot_avatar", | |
| "机器人头像", | |
| type=ConfigType.String, | |
| default="default", | |
| ), | |
| ConfigItem( | |
| "user_avatar", | |
| "用户头像", | |
| type=ConfigType.String, | |
| default="default", | |
| ), | |
| ], | |
| '是否设置机器人头像和用户头像?可填写本地或网络图片链接,或者"none"(不显示头像)。', | |
| ) | |
| # 川虎助理 | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "default_chuanhu_assistant_model", | |
| "川虎助理使用的模型", | |
| type=ConfigType.String, | |
| default="gpt-4", | |
| ), | |
| ConfigItem( | |
| "GOOGLE_CSE_ID", | |
| "谷歌搜索引擎ID(获取方式请看 https://stackoverflow.com/questions/37083058/programmatically-searching-google-in-python-using-custom-search)", | |
| type=ConfigType.String, | |
| ), | |
| ConfigItem( | |
| "GOOGLE_API_KEY", | |
| "谷歌API Key(获取方式请看 https://stackoverflow.com/questions/37083058/programmatically-searching-google-in-python-using-custom-search)", | |
| type=ConfigType.String, | |
| ), | |
| ConfigItem( | |
| "WOLFRAM_ALPHA_APPID", | |
| "Wolfram Alpha API Key(获取方式请看 https://products.wolframalpha.com/api/)", | |
| type=ConfigType.String, | |
| ), | |
| ConfigItem( | |
| "SERPAPI_API_KEY", | |
| "SerpAPI API Key(获取方式请看 https://serpapi.com/)", | |
| type=ConfigType.String, | |
| ), | |
| ], | |
| "是否设置川虎助理?如果不设置,仍可设置川虎助理。如果设置,可以使用川虎助理Pro模式。", | |
| ) | |
| # 文档处理与显示 | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "latex_option", | |
| "LaTeX 公式渲染策略", | |
| type=ConfigType.String, | |
| default="default", | |
| ) | |
| ], | |
| '是否设置文档处理与显示?可选的 LaTeX 公式渲染策略有:"default", "strict", "all"或者"disabled"。', | |
| ) | |
| # 是否隐藏API Key输入框 | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "hide_my_key", | |
| "是否隐藏API Key输入框", | |
| type=ConfigType.Bool, | |
| default=False, | |
| ) | |
| ], | |
| "是否隐藏API Key输入框?如果设置,将不会在 UI 中显示API Key输入框。", | |
| ) | |
| # 是否指定可用模型列表 | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "available_models", | |
| "可用模型列表", | |
| type=ConfigType.ListOfStrings, | |
| ) | |
| ], | |
| "是否指定可用模型列表?如果设置,将只会在 UI 中显示指定的模型。默认展示所有模型。可用的模型有:" | |
| + "\n".join(ONLINE_MODELS) | |
| + "\n".join(LOCAL_MODELS), | |
| ) | |
| # 添加模型到列表 | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "extra_models", | |
| "额外模型列表", | |
| type=ConfigType.ListOfStrings, | |
| ) | |
| ], | |
| "是否添加模型到列表?例如,训练好的GPT模型可以添加到列表中。可以在UI中自动添加模型到列表。", | |
| ) | |
| # 分享 | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "server_name", | |
| "服务器地址,例如设置为 0.0.0.0 则可以通过公网访问(如果你用公网IP)", | |
| type=ConfigType.String, | |
| ), | |
| ConfigItem( | |
| "server_port", | |
| "服务器端口", | |
| type=ConfigType.Number, | |
| default=7860, | |
| ), | |
| ], | |
| "是否配置运行地址和端口?(不建议设置)", | |
| ) | |
| wizard.set( | |
| [ | |
| ConfigItem( | |
| "share", | |
| "是否通过gradio分享?", | |
| type=ConfigType.Bool, | |
| default=False, | |
| ) | |
| ], | |
| "是否通过gradio分享?可以通过公网访问。", | |
| ) | |
| wizard.save() | |
| print(colorama.Back.GREEN + i18n("设置完成。现在请重启本程序。") + colorama.Style.RESET_ALL) | |
| exit() | |
def save_pkl(data, file_path):
    """Serialize *data* with pickle and write it to *file_path*.

    Overwrites any existing file at the target path.
    """
    payload = pickle.dumps(data)
    with open(file_path, "wb") as fh:
        fh.write(payload)
def load_pkl(file_path):
    """Read *file_path* and return the unpickled object stored in it.

    NOTE(review): pickle can execute arbitrary code during loading — only
    use this on files produced by this application, never untrusted input.
    """
    with open(file_path, "rb") as fh:
        return pickle.load(fh)
def chinese_preprocessing_func(text: str) -> List[str]:
    """Segment Chinese *text* into a list of tokens using jieba.

    jieba is imported lazily so the (slow) dictionary load only happens
    when segmentation is actually needed.
    """
    import jieba

    # Suppress jieba's initialization messages on first use.
    jieba.setLogLevel('ERROR')
    return list(jieba.cut(text))