Spaces:
Runtime error
Runtime error
"""
General-Purpose LM Interview Interface
Author: Dr Musashi Hinck

Version Log:
- 2024.01.29: Prototype without a separate launching interface, for demoing in the SPIA class.
    - Removed URL decoding.
    - Read the system prompt and initial message from files.
    - The interview begins with the user entering a name or alias.
    - Azure OpenAI? (open question)
- 2024.01.31: wandb does not work for this use case; what to do instead?
    - Write to a local file and upload at the end? (Does filestream cause blocking?)
- 2024.03.03: Created a new instance for demoing to the IRB.
"""
| from __future__ import annotations | |
| import os | |
| import logging | |
| import json | |
| import wandb | |
| import gradio as gr | |
| from typing import Generator, Any | |
| from pathlib import Path | |
| logger = logging.getLogger(__name__) | |
| from utils import ( | |
| PromptTemplate, | |
| convert_gradio_to_openai, | |
| initialize_client, | |
| seed_azure_key | |
| ) | |
# %% Initialization
# Directory holding initial_message.txt, system_message.txt and config.json.
# NOTE: relative path, so it is resolved against the current working directory.
CONFIG_DIR: Path = Path("./CogDebIRB")

# Only seed Azure credentials from the local files when the environment does
# not already define the endpoint.
if "AZURE_ENDPOINT" not in os.environ:
    seed_azure_key()
client = initialize_client()
# %% (functions)
def load_config(
    path: Path | str,
) -> tuple[str, str, dict[str, str | float], dict[str, str | list[str]]]:
    """Read the interview configuration stored in directory ``path``.

    Files read:
      - ``initial_message.txt``: the bot's opening message (whitespace-stripped).
      - ``system_message.txt``: the system prompt (whitespace-stripped).
      - ``config.json``: optional ``"model_args"`` and ``"wandb_args"`` objects.

    Args:
        path: Config directory, as a ``Path`` or a string.

    Returns:
        ``(initial_message, system_message, model_args, wandb_args)``.
        ``model_args`` defaults to ``{"model": "gpt4", "temperature": 0.0}``
        when absent. ``wandb_args`` defaults to ``{}`` when absent or null.

    Raises:
        FileNotFoundError: If any of the three files is missing.
        json.JSONDecodeError: If ``config.json`` is not valid JSON.
    """
    path = Path(path)
    # Read explicitly as UTF-8 so prompts with non-ASCII text do not depend on
    # the host locale.
    initial_message: str = (
        (path / "initial_message.txt").read_text(encoding="utf-8").strip()
    )
    system_message: str = (
        (path / "system_message.txt").read_text(encoding="utf-8").strip()
    )
    cfg: dict[str, Any] = json.loads((path / "config.json").read_bytes())
    model_args: dict[str, str | float] = cfg.get(
        "model_args", {"model": "gpt4", "temperature": 0.0}
    )
    # Never hand None downstream: the tracker indexes into this mapping.
    wandb_args: dict[str, str | list[str]] = cfg.get("wandb_args") or {}
    return initial_message, system_message, model_args, wandb_args
def initialize_interview(
    initial_message: str,
) -> tuple[gr.Chatbot, gr.Textbox, gr.Button, gr.Button, gr.Button]:
    """Reveal the chat UI, seeded with the bot's opening message.

    Args:
        initial_message: Text the bot opens the interview with.

    Returns:
        Component updates in the order
        ``(chatDisplay, chatInput, chatSubmit, startInterview, resetButton)``.
    """
    # The user slot of the first pair is empty because the bot speaks first.
    opening_turn = [None, initial_message]
    chat_display = gr.Chatbot(visible=True, value=[opening_turn])
    chat_input = gr.Textbox(
        placeholder="Type response here. Hit 'enter' to submit.",
        visible=True,
        interactive=True,
    )
    submit_button = gr.Button(visible=True, interactive=True)
    start_button = gr.Button(visible=False)  # start is no longer needed
    reset_button = gr.Button(visible=True)
    return chat_display, chat_input, submit_button, start_button, reset_button
def initialize_tracker(
    model_args: dict[str, str | float],
    system_message: PromptTemplate | str,
    userid: str,
    wandb_args: dict[str, str | list[str]] | None,
) -> gr.Textbox:
    """Start a WandB run for this interview and hide the name/alias box.

    Args:
        model_args: Model parameters. They are merged into the run config.
        system_message: System prompt, stored as a string in the run config.
        userid: Name or alias typed by the respondent. It is used as the run
            name, and an empty value lets wandb generate a name.
        wandb_args: Optional ``"project"`` and ``"tags"`` entries. ``None`` or
            missing keys fall back to wandb's own defaults instead of raising.

    Returns:
        A Textbox update that clears and hides the name/alias box.
    """
    wandb_args = wandb_args or {}
    run_config = model_args | {
        "system_message": str(system_message),
        "userid": userid,
    }
    logger.info("Initializing WandB run for %s", userid)
    # NOTE(review): wandb.init creates a process-wide run. Concurrent Gradio
    # sessions served by the same process may share it; confirm this is acceptable.
    wandb.init(
        project=wandb_args.get("project"),
        name=userid or None,
        config=run_config,
        tags=wandb_args.get("tags"),
    )
    return gr.Textbox(value=None, visible=False)
def save_interview(
    chat_history: list[list[str | None]],
) -> None:
    """Write the transcript to disk and log it to the active WandB run.

    The raw ``[user, bot]`` pairs are dumped as JSON to
    ``CONFIG_DIR/transcript.json``, overwriting any previous file. The pairs
    are then flattened into ``[role, message]`` rows, with ``None`` entries
    skipped, and logged as a ``wandb.Table`` under ``"chat_history"``.
    """
    # NOTE(review): the path is a module-level constant, so every session
    # overwrites the same file.
    transcript_path = CONFIG_DIR / "transcript.json"
    with transcript_path.open("w") as transcript_file:
        json.dump(chat_history, transcript_file, indent=2)

    rows = [
        [role, text]
        for pair in chat_history
        for role, text in (("user", pair[0]), ("bot", pair[1]))
        if text is not None
    ]
    transcript_table = wandb.Table(data=rows, columns=["role", "message"])
    logger.info("Uploading interview transcript to WandB...")
    wandb.log({"chat_history": transcript_table})
    logger.info("Uploading complete.")
def user_message(
    message: str, chat_history: list[list[str | None]]
) -> tuple[str, list[list[str | None]]]:
    """Append the user's message as a new turn so it displays immediately.

    Returns an empty string to clear the input box, plus a new history list.
    The input list is not mutated. The bot slot of the new turn is ``None``
    until the reply streams in.
    """
    pending_turn: list[str | None] = [message, None]
    cleared_input = ""
    return cleared_input, chat_history + [pending_turn]
def bot_message(
    chat_history: list[list[str | None]],
    system_message: str,
    model_args: dict[str, str | float],
) -> Generator[list[list[str | None]], None, None]:
    """Stream the assistant's reply to the latest user turn into the history.

    The request is built as the system prompt, then the earlier turns converted
    by ``convert_gradio_to_openai``, then the latest user message. The bot slot
    of the last turn, ``chat_history[-1][1]``, is reset to ``""`` and extended
    in place. The same list is yielded after each non-empty delta.

    Args:
        chat_history: Gradio ``[user, bot]`` pairs. The last pair holds the
            new user message.
        system_message: System prompt for the model.
        model_args: Extra keyword arguments for ``chat.completions.create``,
            such as ``model`` and ``temperature``.
    """
    # Prep messages
    user_msg = chat_history[-1][0]
    prior_turns = convert_gradio_to_openai(chat_history[:-1])
    messages = (
        [{"role": "system", "content": system_message}]
        + prior_turns
        + [{"role": "user", "content": user_msg}]
    )
    # API call
    response = client.chat.completions.create(
        messages=messages, stream=True, **model_args
    )
    # Streaming
    chat_history[-1][1] = ""
    for chunk in response:
        # Azure OpenAI streams can contain chunks with an empty ``choices``
        # list, e.g. the leading prompt content-filter annotations. Indexing
        # [0] on such a chunk raises IndexError, so skip it.
        if not chunk.choices:
            continue
        # Defensive: tolerate a chunk whose delta is missing.
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            chat_history[-1][1] += delta
            yield chat_history
def reset_interview() -> (
    tuple[gr.Chatbot, gr.Textbox, gr.Button, gr.Textbox, gr.Button, gr.Button]
):
    """End the WandB run and restore the pre-interview layout.

    Calls ``wandb.finish()`` and shows an "Interview reset." toast.

    Returns:
        Component updates in the order
        ``(chatDisplay, chatInput, chatSubmit, userBox, startInterview,
        resetButton)``. The chat UI is hidden and emptied, and the name/alias
        box and start button are shown again.
    """
    wandb.finish()
    gr.Info("Interview reset.")
    return (
        gr.Chatbot(visible=False, value=[]),  # chatDisplay
        gr.Textbox(visible=False),  # chatInput
        gr.Button(visible=False),  # chatSubmit
        gr.Textbox(value=None, visible=True),  # userBox
        gr.Button(visible=True),  # startInterview
        gr.Button(visible=False),  # resetButton
    )
# LAYOUT
with gr.Blocks(theme="sudeepshouche/minimalist") as demo:
    gr.Markdown("# Chat Interview Interface")
    userDisplay = gr.Markdown("", visible=False)

    # Per-session config holders, filled by load_config when the interview starts.
    configDir = gr.State(value=CONFIG_DIR)
    initialMessage = gr.Textbox(visible=False)
    systemMessage = gr.Textbox(visible=False)
    modelArgs = gr.State(value={"model": "", "temperature": ""})
    wandbArgs = gr.State(value={"project": "", "tags": []})

    ## Start interview by entering name or alias
    userBox = gr.Textbox(
        value=None,
        placeholder="Enter name or alias and hit 'enter' to begin.",
        show_label=False,
    )
    startInterview = gr.Button("Start Interview", variant="primary", visible=True)

    ## RESPONDENT (hidden until the interview starts)
    chatDisplay = gr.Chatbot(show_label=False, visible=False)
    with gr.Row():
        chatInput = gr.Textbox(
            placeholder="Click 'Start Interview' to begin.",
            visible=False,
            interactive=False,
            show_label=False,
            scale=10,
        )
        chatSubmit = gr.Button(
            "",
            variant="primary",
            interactive=False,
            icon="./arrow_icon.svg",
            visible=False,
        )
    resetButton = gr.Button("Save and Exit", visible=False, variant="stop")
    disclaimer = gr.HTML(
        """
<div
style='font-size: 1em;
font-style: italic;
position: fixed;
left: 50%;
bottom: 20px;
transform: translate(-50%, -50%);
margin: 0 auto;
'
>{}</div>
""".format(
            "Statements by the chatbot may contain factual inaccuracies."
        )
    )

    ## INTERACTIONS
    # Mirror the typed alias into the (hidden) user display.
    userBox.change(
        lambda x: x, inputs=[userBox], outputs=[userDisplay], show_progress=False
    )

    # Start interview: pressing enter in the name box and clicking the button
    # run the same chain, registered in the same order. The chain loads the
    # config, reveals the chat UI, then opens the WandB run.
    for start_trigger in (userBox.submit, startInterview.click):
        start_trigger(
            load_config,
            inputs=configDir,
            outputs=[initialMessage, systemMessage, modelArgs, wandbArgs],
        ).then(
            initialize_interview,
            inputs=[initialMessage],
            outputs=[chatDisplay, chatInput, chatSubmit, startInterview, resetButton],
        ).then(
            initialize_tracker,
            inputs=[modelArgs, systemMessage, userBox, wandbArgs],
            outputs=[userBox],
        )

    # Chat interaction: "Enter" in the textbox and the arrow button behave the
    # same way. The chain echoes the user turn, streams the reply, then saves.
    for send_trigger in (chatInput.submit, chatSubmit.click):
        send_trigger(
            user_message,
            inputs=[chatInput, chatDisplay],
            outputs=[chatInput, chatDisplay],
            queue=False,
        ).then(
            bot_message,
            inputs=[chatDisplay, systemMessage, modelArgs],
            outputs=[chatDisplay],
        ).then(save_interview, inputs=[chatDisplay])

    # Reset button: save the transcript first, then restore the start screen.
    resetButton.click(save_interview, [chatDisplay]).then(
        reset_interview,
        outputs=[
            chatDisplay,
            chatInput,
            chatSubmit,
            userBox,
            startInterview,
            resetButton,
        ],
        show_progress=False,
    )

if __name__ == "__main__":
    demo.launch()