Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve: Research Agent Regex Parsing #469

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,10 @@ cython_debug/
.idea/

notes.md
data/

# Local installation
data/
ui/

# Vscode
.vscode/
41 changes: 29 additions & 12 deletions devika.py
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ONLY BETTER FORMATTING : NO MODIFICATIONS

Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
DO NOT REARRANGE THE ORDER OF THE FUNCTION CALLS AND VARIABLE DECLARATIONS
AS IT MAY CAUSE IMPORT ERRORS AND OTHER ISSUES
"""

from gevent import monkey

monkey.patch_all()
from src.init import init_devika

init_devika()


Expand All @@ -21,7 +24,7 @@
from src.logger import Logger, route_logger
from src.project import ProjectManager
from src.state import AgentState
from src.agents import Agent
from src.agents import Agents
from src.llm import LLM


Expand All @@ -46,7 +49,7 @@


# initial socket
@socketio.on('socket_connect')
@socketio.on("socket_connect")
def test_connect(data):
print("Socket connected :: ", data)
emit_agent("socket_response", {"data": "Server Connected"})
Expand All @@ -58,7 +61,9 @@ def data():
project = manager.get_project_list()
models = LLM().list_models()
search_engines = ["Bing", "Google", "DuckDuckGo"]
return jsonify({"projects": project, "models": models, "search_engines": search_engines})
return jsonify(
{"projects": project, "models": models, "search_engines": search_engines}
)


@app.route("/api/messages", methods=["POST"])
Expand All @@ -70,36 +75,47 @@ def get_messages():


# Main socket
@socketio.on('user-message')
@socketio.on("user-message")
def handle_message(data):
logger.info(f"User message: {data}")
message = data.get('message')
base_model = data.get('base_model')
project_name = data.get('project_name')
search_engine = data.get('search_engine').lower()
message = data.get("message")
base_model = data.get("base_model")
project_name = data.get("project_name")
search_engine = data.get("search_engine").lower()

agent = Agent(base_model=base_model, search_engine=search_engine)
agent = Agents(base_model=base_model, search_engine=search_engine)

state = AgentState.get_latest_state(project_name)
if not state:
thread = Thread(target=lambda: agent.execute(message, project_name))
thread.start()
else:
if AgentState.is_agent_completed(project_name):
thread = Thread(target=lambda: agent.subsequent_execute(message, project_name))
thread = Thread(
target=lambda: agent.subsequent_execute(message, project_name)
)
thread.start()
else:
emit_agent("info", {"type": "warning", "message": "previous agent doesn't completed it's task."})
emit_agent(
"info",
{
"type": "warning",
"message": "previous agent doesn't completed it's task.",
},
)
last_state = AgentState.get_latest_state(project_name)
if last_state["agent_is_active"] or not last_state["completed"]:
# emit_agent("info", {"type": "info", "message": "I'm trying to complete the previous task again."})
# message = manager.get_latest_message_from_user(project_name)
thread = Thread(target=lambda: agent.execute(message, project_name))
thread.start()
else:
thread = Thread(target=lambda: agent.subsequent_execute(message, project_name))
thread = Thread(
target=lambda: agent.subsequent_execute(message, project_name)
)
thread.start()


@app.route("/api/is-agent-active", methods=["POST"])
@route_logger(logger)
def is_agent_active():
Expand Down Expand Up @@ -203,6 +219,7 @@ def get_settings():
def status():
return jsonify({"status": "server is running!"}), 200


if __name__ == "__main__":
logger.info("Devika is up and running!")
socketio.run(app, debug=False, port=1337, host="0.0.0.0")
5 changes: 3 additions & 2 deletions src/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .agent import Agent
from .agents import Agents
from .agent_template import AgentTemplate

from .planner import Planner
from .internal_monologue import InternalMonologue
from .researcher import Researcher
from .formatter import Formatter
from .coder import Coder
from .action import Action
from .runner import Runner
from .runner import Runner
135 changes: 135 additions & 0 deletions src/agents/agent_template.py
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

THE MAIN MODIFICATION

The parent class of all agents uses regex to efficiently parse LLM responses

Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import json
import re
import os
import inspect

from jinja2 import BaseLoader, Environment
from src.logger import Logger

logger = Logger()


class AgentTemplate:
""" "
This class is the parent class of all the agents. It defines the methods and attributes common to all agents.
"""

def __init__(self):
pass

def render(self, *args, **kwargs) -> str:
"""
This method renders the prompt template of the child class with the provided arguments.

Args:
*args: SHOULD NOT BE USED.
**kwargs: The arguments to provide to the prompt template.

Returns:
str: The rendered prompt.
"""
# Display a warning if args are provided
if args:
logger.warning(
f"INTERNAL ERROR : The render method of the AgentTemplate class should only be called with kwargs. Received args: {args}"
)

# Load the prompt template of the child class
env = Environment(loader=BaseLoader())
template_path = os.path.join(
os.path.dirname(inspect.getfile(self.__class__)), "prompt.jinja2"
)
with open(template_path, "r") as f:
template_string = f.read()
template = env.from_string(template_string)

# Check if all the variables in the template are provided
required_variables = re.findall(r"{{ (.*?) }}", template_string)
for variable in required_variables:
if variable not in kwargs:
raise ValueError(f"Missing variable {variable} in the render method.")

return template.render(**kwargs)

def parse_answer(self, response: str) -> dict | bool:
"""
This method try to parse the response from the model to a dict based on the prompt structure.
If it fails, it returns False.

Args:
response (str): The raw response from the model.

Returns:
dict | bool: The parsed response or False.
"""
try:
final_json = self.find_json_blocks(response)
except Exception as _:
return False

try:
final_json = json.loads(final_json)
except Exception as _:
return False

# Get required fields from the response
with open(
os.path.join(
os.path.dirname(inspect.getfile(self.__class__)), "prompt.jinja2"
),
"r",
) as f:
template = f.read()
template_json = self.find_json_blocks(template)
required_fields = list(json.loads(template_json).keys())

# Check if all the required fields are present in the response
for field in required_fields:
if field not in final_json:
logger.warning(f"Missing field {field} in the response.")
return False

return {field: final_json[field] for field in required_fields}

def find_json_blocks(self, text: str) -> str:
"""
This method extracts the JSON blocks from the text.

Args:
text (str): The text to extract the JSON blocks from.

Returns:
str: The extracted JSON blocks as a string that can be parsed using `json.loads`.

Raises:
Exception: If the JSON blocks cannot be parsed.
"""
# Remove eventually unrendered jinja2 blocks and extract JSON blocks
json_blocks = re.findall(r"{(.*?)}", re.sub(r"{{.*?}}", "", text), re.DOTALL)

try:
parsed_json_results = []
for block in json_blocks:
cleaned_block = block.replace("\n", "").replace("```\n\n```", "")
parsed_block = json.loads("{" + cleaned_block + "}")
parsed_json_results.append(parsed_block)
except Exception as e:
logger.warning(f"Error while parsing JSON blocks: {e}")
raise e

# Try to merge all the JSON blocks into a single JSON dict
try:
final_json = {}
for parsed_block in parsed_json_results:
for key, value in parsed_block.items():
if key in final_json:
final_json[key] += value
else:
final_json[key] = value

final_json = json.dumps(final_json, indent=4)
except Exception as e:
logger.warning(f"Error while merging JSON blocks: {e}")
raise e

return final_json