-
Notifications
You must be signed in to change notification settings - Fork 0
/
chatdemo.py
138 lines (113 loc) · 4.14 KB
/
chatdemo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
from typing import Optional, Tuple
import requests
import gradio as gr
from langchain.chains import ConversationChain
from langchain.llms import OpenAI
from threading import Lock
import os
import getpass
from prompter_245 import prompt
#os.environ["OPENAI_API_KEY"] = getpass.getpass("")
from langchain.agents import Tool, AgentType, initialize_agent
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
from langchain.utilities import DuckDuckGoSearchAPIWrapper
from langchain.agents import AgentExecutor
from langchain import hub
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import ReActSingleInputOutputParser
from langchain.tools.render import render_text_description
import gradio as gr
import asyncio
from gradio_tools.tools.gradio_tool import GradioTool
def load_chain():
    """Build and return the ConversationChain used to answer queries.

    The LLM temperature is pinned to 0 for deterministic replies.
    """
    return ConversationChain(llm=OpenAI(temperature=0))
def set_openai_api_key(api_key: str):
    """Store the user's OpenAI key in the environment and return a chain.

    Args:
        api_key: The key pasted into the UI textbox.

    Returns:
        A freshly loaded ConversationChain, or None when no key was given.
    """
    if not api_key:
        # No key pasted yet: signal "no chain" to the caller.
        return None
    # BUG FIX: the pasted key was previously thrown away and the env var
    # hard-coded to "" — use the value the user actually provided.
    os.environ["OPENAI_API_KEY"] = api_key
    return load_chain()
def query_fastapi(input_text):
    """Forward *input_text* to the local FastAPI backend and return its reply.

    Returns the backend's "output" field on success, or an error string
    containing the response body on any non-200 status.
    """
    endpoint = "http://localhost:8000/query/"
    resp = requests.post(endpoint, json={"input": input_text})
    if resp.status_code != 200:
        return f"Error: {resp.text}"
    return resp.json().get("output", "No output received")
class ChatWrapper:
    """Thread-safe callable used as the Gradio click/submit handler.

    A Lock serialises concurrent callback invocations so the shared
    history list is never mutated by two requests at once.
    """

    def __init__(self):
        # One lock per wrapper instance; held for the duration of a call.
        self.lock = Lock()

    def __call__(
        self,
        api_key: str,
        inp: str,
        history: Optional[list],  # list of (user, bot) message tuples
        chain: Optional["ConversationChain"],
    ):
        """Execute the chat functionality.

        Returns:
            The updated history twice — once for the Chatbot widget and
            once for the session State component.
        """
        # `with` guarantees release even when an exception propagates
        # (replaces the manual acquire()/try/finally/release() and the
        # no-op `except Exception as e: raise e` of the original).
        with self.lock:
            history = history or []
            # chain is None when no API key has been provided yet.
            if chain is None:
                history.append((inp, "Please paste your OpenAI key to use"))
                return history, history
            # Set the OpenAI key for this request; imported lazily so the
            # module loads even when the openai package is absent.
            import openai
            openai.api_key = api_key
            # The actual model call is delegated to the FastAPI backend.
            output = query_fastapi(inp)
            history.append((inp, output))
        return history, history
# --- Gradio UI wiring (executes at import time) ---
# Single shared handler; its internal Lock serialises concurrent callbacks.
chat = ChatWrapper()
block = gr.Blocks(css=".gradio-container {background-color: lightgray}")
with block:
    with gr.Row():
        gr.Markdown("<h1><center>DICE (Disaster Intervention Chatbot Engine)</center></h1>")
        # Masked textbox so the pasted OpenAI key is never shown on screen.
        openai_api_key_textbox = gr.Textbox(
            placeholder="Paste your OpenAI API key (sk-...)",
            show_label=False,
            lines=1,
            type="password",
        )

    chatbot = gr.Chatbot()

    with gr.Row():
        message = gr.Textbox(
            label="What's your question?",
            placeholder="Ask any question related to natural disasters",
            lines=1,
        )
        # NOTE(review): .style() is a legacy Gradio 3.x API — confirm the
        # pinned gradio version before upgrading.
        submit = gr.Button(value="Send", variant="secondary").style(full_width=False)

    # Clickable example prompts that populate the message textbox.
    gr.Examples(
        examples=[
            "There's a wildfire what should i do?",
            "Can you help me with disaster recovery centers near me?",
            "What are some things I should keep in mind in case of a wildfire?",
        ],
        inputs=message,
    )

    gr.HTML("DICE AGI.")
    gr.HTML(
        "<center>Powered by AI TOOLS</center>"
    )

    # Per-session state: chat history, and the chain produced by
    # set_openai_api_key (None until a key is pasted).
    state = gr.State()
    agent_state = gr.State()

    # Both the Send button and pressing Enter route through ChatWrapper.
    submit.click(chat, inputs=[openai_api_key_textbox, message, state, agent_state], outputs=[chatbot, state])
    message.submit(chat, inputs=[openai_api_key_textbox, message, state, agent_state], outputs=[chatbot, state])

    # Changing the key textbox rebuilds the chain and stores it in agent_state.
    openai_api_key_textbox.change(
        set_openai_api_key,
        inputs=[openai_api_key_textbox],
        outputs=[agent_state],
    )

block.launch(debug=True)