"""
General-Purpose LM Interview Interface
Author: Dr Musashi Hinck
Version Log:
- 2024.01.29: prototype without separate launching interface for demoing in SPIA class.
- Remove URL decoding
- Read sysprompt and initial_message from file
- Begins with user entering name/alias
- Azure OpenAI?
- 2024.01.31: wandb does not work for use case, what to do instead?
- Write to local file and then upload at end? (does filestream cause blocking?)
- 2024.03.03: Creating new instance for demoing to IRB
"""
from __future__ import annotations
import os
import logging
import json
import wandb
import gradio as gr
from typing import Generator, Any
from pathlib import Path
from utils import (
    PromptTemplate,
    convert_gradio_to_openai,
    initialize_client,
    seed_azure_key,
)

logger = logging.getLogger(__name__)
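# NOTE: `utils` is a project-local helper module that is not shown in this file.
# Based on how the helpers are used below, they are assumed to behave as follows:
#   - PromptTemplate: wraps a prompt string (only str() is called on it here)
#   - initialize_client: builds the (Azure) OpenAI chat client used in bot_message
#   - seed_azure_key: loads Azure credentials from local files into the environment
#   - convert_gradio_to_openai: turns Gradio [user, bot] pairs into OpenAI-style
#     message dicts, roughly like this sketch (not the actual implementation):
#       def convert_gradio_to_openai(history):
#           messages = []
#           for user_msg, bot_msg in history:
#               if user_msg is not None:
#                   messages.append({"role": "user", "content": user_msg})
#               if bot_msg is not None:
#                   messages.append({"role": "assistant", "content": bot_msg})
#           return messages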
# %% Initialization
CONFIG_DIR: Path = Path("./CogDebIRB")
if os.environ.get("AZURE_ENDPOINT") is None: # Set Azure credentials from local files
seed_azure_key()
client = initialize_client()
# %% Functions
def load_config(
path: Path,
) -> tuple[str, str, dict[str, str | float], dict[str, str | list[str]]]:
"Read configs, return inital_message, system_message, model_args, wandb_args"
initial_message: str = (path / "initial_message.txt").read_text().strip()
system_message: str = (path / "system_message.txt").read_text().strip()
    cfg: dict[str, Any] = json.loads((path / "config.json").read_bytes())
    model_args: dict[str, str | float] = cfg.get(
        "model_args", {"model": "gpt4", "temperature": 0.0}
    )
    wandb_args: dict[str, str | list[str]] = cfg.get("wandb_args")
return initial_message, system_message, model_args, wandb_args
def initialize_interview(
initial_message: str,
) -> tuple[gr.Chatbot,
gr.Textbox,
gr.Button,
gr.Button,
gr.Button]:
"Read system prompt and start interview. Change visibilities of elements."
    chat_history = [
        [None, initial_message]
    ]  # The first element of each pair is the user turn; the bot opens here, so it is None.
return (
gr.Chatbot(visible=True, value=chat_history), # chatDisplay
gr.Textbox(
placeholder="Type response here. Hit 'enter' to submit.",
visible=True,
interactive=True,
), # chatInput
gr.Button(visible=True, interactive=True), # chatSubmit
gr.Button(visible=False), # startInterview
gr.Button(visible=True), # resetButton
)
def initialize_tracker(
model_args: dict[str, str | float],
    system_message: str | PromptTemplate,
userid: str,
wandb_args: dict[str, str | list[str]],
) -> gr.Textbox:
"Initializes wandb run for interview. Resets userBox afterwards."
run_config = model_args | {
"system_message": str(system_message),
"userid": userid,
}
logger.info(f"Initializing WandB run for {userid}")
wandb.init(
project=wandb_args["project"],
name=userid,
config=run_config,
tags=wandb_args["tags"],
)
return gr.Textbox(value=None, visible=False)
def save_interview(
chat_history: list[list[str | None]],
) -> None:
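    "Write the transcript to CONFIG_DIR/transcript.json, then log it to the active WandB run."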
# Save chat_history as json
    with open(CONFIG_DIR / "transcript.json", "w") as fh:
json.dump(chat_history, fh, indent=2)
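    # Flatten Gradio's [user, bot] pairs into (role, message) rows for a wandb.Table.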
chat_data = []
for pair in chat_history:
for i, role in enumerate(["user", "bot"]):
if pair[i] is not None:
chat_data += [[role, pair[i]]]
chat_table = wandb.Table(data=chat_data, columns=["role", "message"])
logger.info("Uploading interview transcript to WandB...")
wandb.log({"chat_history": chat_table})
logger.info("Uploading complete.")
def user_message(
message: str, chat_history: list[list[str | None]]
) -> tuple[str, list[list[str | None]]]:
"Display user message immediately"
return "", chat_history + [[message, None]]
def bot_message(
chat_history: list[list[str | None]],
system_message: str,
model_args: dict[str, str | float],
) -> Generator[Any, Any, Any]:
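    "Stream the assistant reply from the chat completions API into the last chat turn."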
# Prep messages
user_msg = chat_history[-1][0]
messages = convert_gradio_to_openai(chat_history[:-1])
messages = (
[{"role": "system", "content": system_message}]
+ messages
+ [{"role": "user", "content": user_msg}]
)
# API call
response = client.chat.completions.create(
messages=messages, stream=True, **model_args
)
# Streaming
chat_history[-1][1] = ""
    for chunk in response:
        # Azure OpenAI streams may emit chunks with empty `choices`
        # (e.g. content-filter metadata); skip them to avoid an IndexError.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta.content
        if delta:
            chat_history[-1][1] += delta
            yield chat_history
def reset_interview() -> (
    tuple[
        gr.Chatbot, gr.Textbox, gr.Button, gr.Textbox, gr.Button, gr.Button
    ]
):
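    "Finish the WandB run and return the UI to its pre-interview state."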
wandb.finish()
gr.Info("Interview reset.")
return (
gr.Chatbot(visible=False, value=[]), # chatDisplay
gr.Textbox(visible=False), # chatInput
gr.Button(visible=False), # chatSubmit
gr.Textbox(value=None, visible=True), # userBox
gr.Button(visible=True), # startInterview
gr.Button(visible=False), # resetButton
)
# LAYOUT
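# The interface has two phases: a sign-in phase (userBox / startInterview) and a
# chat phase (chatDisplay / chatInput / chatSubmit / resetButton). Hidden Textbox
# and State components below carry config values between event callbacks.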
with gr.Blocks(theme="sudeepshouche/minimalist") as demo:
gr.Markdown("# Chat Interview Interface")
userDisplay = gr.Markdown("", visible=False)
# Config values
configDir = gr.State(value=CONFIG_DIR)
initialMessage = gr.Textbox(visible=False)
systemMessage = gr.Textbox(visible=False)
modelArgs = gr.State(value={"model": "", "temperature": ""})
wandbArgs = gr.State(value={"project": "", "tags": []})
## Start interview by entering name or alias
userBox = gr.Textbox(
value=None, placeholder="Enter name or alias and hit 'enter' to begin.", show_label=False
)
startInterview = gr.Button("Start Interview", variant="primary", visible=True)
## RESPONDENT
chatDisplay = gr.Chatbot(show_label=False, visible=False)
with gr.Row():
chatInput = gr.Textbox(
placeholder="Click 'Start Interview' to begin.",
visible=False,
interactive=False,
show_label=False,
scale=10,
)
chatSubmit = gr.Button(
"",
variant="primary",
interactive=False,
icon="./arrow_icon.svg",
visible=False,
)
resetButton = gr.Button("Save and Exit", visible=False, variant="stop")
disclaimer = gr.HTML(
"""
<div
style='font-size: 1em;
font-style: italic;
position: fixed;
left: 50%;
bottom: 20px;
transform: translate(-50%, -50%);
margin: 0 auto;
'
>{}</div>
""".format(
"Statements by the chatbot may contain factual inaccuracies."
)
)
## INTERACTIONS
# Start Interview button
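    # Submitting the name box and clicking "Start Interview" trigger the same chain:
    # load the config files -> reveal the chat UI -> start the WandB run (which also
    # clears and hides userBox).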
userBox.change(lambda x: x, inputs=[userBox], outputs=[userDisplay], show_progress=False)
userBox.submit(
load_config,
inputs=configDir,
outputs=[initialMessage, systemMessage, modelArgs, wandbArgs],
).then(
initialize_interview,
inputs=[initialMessage],
outputs=[
chatDisplay,
chatInput,
chatSubmit,
startInterview,
resetButton,
],
).then(
initialize_tracker,
inputs=[modelArgs, systemMessage, userBox, wandbArgs],
outputs=[userBox]
)
startInterview.click(
load_config,
inputs=configDir,
outputs=[initialMessage, systemMessage, modelArgs, wandbArgs],
).then(
initialize_interview,
inputs=[initialMessage],
outputs=[
chatDisplay,
chatInput,
chatSubmit,
startInterview,
resetButton,
],
).then(
initialize_tracker,
inputs=[modelArgs, systemMessage, userBox, wandbArgs],
outputs=[userBox]
)
# Chat interaction
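    # Pressing Enter in the textbox and clicking the submit button share the same chain:
    # echo the user message -> stream the model reply -> save/upload the transcript.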
# "Enter"
chatInput.submit(
user_message,
inputs=[chatInput, chatDisplay],
outputs=[chatInput, chatDisplay],
queue=False,
).then(
bot_message,
inputs=[chatDisplay, systemMessage, modelArgs],
outputs=[chatDisplay],
).then(
save_interview, inputs=[chatDisplay]
)
# Button
chatSubmit.click(
user_message,
inputs=[chatInput, chatDisplay],
outputs=[chatInput, chatDisplay],
queue=False,
).then(
bot_message,
inputs=[chatDisplay, systemMessage, modelArgs],
outputs=[chatDisplay],
).then(
save_interview, inputs=[chatDisplay]
)
# Reset button
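    # "Save and Exit": persist/upload the transcript one final time, then finish the
    # WandB run and reset the UI.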
resetButton.click(save_interview, [chatDisplay]).then(
reset_interview,
outputs=[
chatDisplay,
chatInput,
chatSubmit,
userBox,
startInterview,
resetButton,
],
show_progress=False,
)
if __name__ == "__main__":
demo.launch()