Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]Add Threading #86

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion erniebot-agent/erniebot_agent/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import Memory
from .base import Memory, MessageManager, PersistentMessageManager
from .limit_token_memory import LimitTokensMemory
from .sliding_window_memory import SlidingWindowMemory
from .whole_memory import WholeMemory
141 changes: 127 additions & 14 deletions erniebot-agent/erniebot_agent/memory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import Dict, List, Optional, Union

from erniebot_agent.messages import AIMessage, Message
from erniebot_agent.messages import AIMessage, HumanMessage, Message

# Test fixtures: module-level, in-memory data read and mutated by the classes below.

# Access key (AK) -> user id.
user_AK_relation = {
    "AK-123": "user-123",
    "AK-124": "user-124",
}

# User id -> ids of the sessions recorded for that user (may be empty).
user_session_id_relation: Dict[str, List] = {
    "user-123": [],
    "user-124": ["session-124", "session-125"],
}

# Session id -> ordered list of the messages exchanged in that session.
session_messages = {
    "session-124": [
        HumanMessage(content="你好"),
        AIMessage(content="你好124", function_call=None),
    ],
    "session-125": [
        HumanMessage(content="你好"),
        AIMessage(content="你好125", function_call=None),
    ],
}


class MessageManager:
Expand All @@ -38,15 +48,125 @@ def clear_messages(self) -> None:
def update_last_message_token_count(self, token_count: int):
    """Store ``token_count`` on the most recently added message.

    Args:
        token_count: Value assigned to the last message's ``token_count``.

    Raises:
        IndexError: If ``self.messages`` is empty.
    """
    latest = self.messages[-1]
    latest.token_count = token_count

def get_messages(self) -> List[Message]:
    """Return the stored messages.

    The live list is returned, not a copy, so callers see later additions.
    """
    return self.messages

def retrieve_messages(self) -> List[Message]:
    """Return the stored messages; kept as an alias of :meth:`get_messages`.

    ``retrieve_messages`` was this method's former name. The alias keeps
    existing callers working after the rename.
    """
    return self.get_messages()


class RemoteMemory:
    """Manage the messages of one user within a single session.

    Every operation reads or mutates the module-level ``session_messages``
    mapping, keyed by ``session_id``.

    Args:
        session_id: Key of the session in ``session_messages``.

    Raises:
        KeyError: If ``session_id`` is not a key of ``session_messages``.
    """

    def __init__(self, session_id):
        if session_id not in session_messages:
            raise KeyError(f"session_id {session_id} not found")
        self.session_id: str = session_id
        # Alias of the stored list (not a copy): it stays in sync with the
        # store as long as the store entry is mutated in place.
        self.messages: List[Message] = session_messages[session_id]

    def add_message(self, message: Message):
        """Append ``message`` to the session's stored messages."""
        session_messages[self.session_id].append(message)

    def pop_message(self):
        """Remove and return the oldest message of the session.

        Raises:
            IndexError: If the session has no messages.
        """
        # The popped message is returned because PersistentMessageManager
        # forwards this method's result to its own caller.
        return session_messages[self.session_id].pop(0)

    def clear_memory(self):
        """Remove every message of the session."""
        # Clear in place instead of rebinding the store entry to a new list,
        # so ``self.messages`` and any other alias of the list stay valid.
        self.messages.clear()
        session_messages[self.session_id] = self.messages

    def get_messages(self):
        """Return the session's stored messages (the live list, not a copy).

        Raises:
            KeyError: If the session is no longer present in the store.
        """
        if self.session_id not in session_messages:
            raise KeyError(f"session_id {self.session_id} not found")
        return session_messages[self.session_id]

    def search_memory(self, session_id, payload):  # TODO: refer zep
        """Placeholder for message search; not implemented, returns None."""

    # TODO: sync message changes to the database once the session is closed


class MessageStorageServer:  # bound to a single user
    """Manage one user's sessions and expose the memory of the selected one.

    Args:
        request_url (str): Request address. It is stored on the instance and
            not otherwise used by this class.
        AK (str): Access key of the user, resolved through ``user_AK_relation``.
        session_id (str, optional): Id of the session to open. Defaults to None,
            which selects the user's last recorded session; a session is
            created first when the user has none.

    Raises:
        KeyError: If ``AK`` is unknown, or the selected session has no entry
            in ``session_messages``.
    """

    def __init__(self, request_url: str, AK: str, session_id: Optional[str] = None):
        self.request_url = request_url
        self.AK = AK
        self.user_id = user_AK_relation[AK]
        # Alias (not a copy) of the user's session list in the store.
        self.sessions: List = user_session_id_relation[self.user_id]
        if not self.sessions:
            self.create_session()

        # NOTE(review): an explicit session_id is not checked against
        # self.sessions, so any existing session can be opened — confirm intended.
        self.session_id = session_id if session_id else self.sessions[-1]  # TODO: session selection
        self.memory = RemoteMemory(self.session_id)

    def get_messages(self):
        """Return the messages of the selected session."""
        return self.memory.get_messages()

    def create_session(self):
        """Create a new, empty session for the user and return its id."""
        import uuid

        session_id = uuid.uuid4().hex  # A new session identifier
        self.sessions.append(session_id)
        # Register the full list. Storing ``[session_id]`` here would drop the
        # user's earlier sessions from the store and detach ``self.sessions``.
        user_session_id_relation[self.user_id] = self.sessions
        session_messages[session_id] = []
        # TODO: also allocate the corresponding space in the database
        return session_id


class PersistentMessageManager:
    """Message manager that keeps messages per user and per session.

    All operations are delegated to the ``RemoteMemory`` owned by a
    ``MessageStorageServer``. ``Memory`` accepts this class as an alternative
    to ``MessageManager``.

    Args:
        url: Request address, forwarded to ``MessageStorageServer``.
        AK: Access key of the user, forwarded to ``MessageStorageServer``.
        session_id: Session to open. Defaults to None, which lets
            ``MessageStorageServer`` choose the session.
    """

    def __init__(self, url: str, AK: str, session_id: Optional[str] = None):
        # The client decides which session is actually used.
        self.client = MessageStorageServer(request_url=url, AK=AK, session_id=session_id)
        # Mirror the client's choice so both sides agree on the session id.
        self.session_id = self.client.session_id
        self.messages = self.get_messages()

    def add_message(self, message: Message):
        """Append ``message`` to the session's stored messages."""
        self.client.memory.add_message(message=message)

    def clear_messages(self):
        """Remove every message of the session."""
        self.client.memory.clear_memory()
        # Re-read the list after clearing. Binding ``self.messages`` to an
        # unrelated empty list would leave it detached from the store.
        self.messages = self.get_messages()

    def pop_message(self):  # TODO: choose from pop_message and cherry_pick_message
        """Pop the oldest message and return whatever the backing memory returns."""
        delete_message = self.client.memory.pop_message()
        return delete_message

    def get_messages(self) -> List[Message]:
        """Return the session's stored messages."""
        # system, AI, user; contains summary if necessary
        memory = self.client.memory.get_messages()
        return memory

    # TODO: add cherry_pick_message(query), which would use the storage
    # backend's index (e.g. zep_python's MemorySearchPayload passed to
    # client.memory.search_memory) to find related messages instead of
    # popping. The result length would not be bounded.

    def update_last_message_token_count(self, token_count: int):
        """Store ``token_count`` on the most recently stored message.

        Raises:
            IndexError: If the session has no messages.
        """
        self.client.memory.get_messages()[-1].token_count = token_count


class Memory:
"""The base class of memory"""

def __init__(
    self,
    message_manager: Optional[Union[PersistentMessageManager, MessageManager]] = None,
):
    """Create a memory backed by ``message_manager``.

    Args:
        message_manager: Manager that stores the messages. Defaults to None,
            in which case a new ``MessageManager`` is created for this
            instance.
    """
    # A ``MessageManager()`` default in the signature is evaluated once, when
    # the function is defined, and would then be shared (messages included)
    # by every Memory created without an argument.
    self.msg_manager = MessageManager() if message_manager is None else message_manager

def add_messages(self, messages: List[Message]):
for message in messages:
Expand All @@ -55,17 +175,10 @@ def add_messages(self, messages: List[Message]):
def add_message(self, message: Message):
if isinstance(message, AIMessage):
self.msg_manager.update_last_message_token_count(message.query_tokens_count)
self.msg_manager.add_message(message)
self.msg_manager.add_message(message=message)

def get_messages(self) -> List[Message]:
return self.msg_manager.retrieve_messages()
return self.msg_manager.get_messages()

def clear_chat_history(self):
self.msg_manager.clear_messages()


class WholeMemory(Memory):
"""The memory include all the messages"""

def __init__(self):
super().__init__()
6 changes: 3 additions & 3 deletions erniebot-agent/erniebot_agent/memory/limit_token_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.


from erniebot_agent.memory import Memory
from erniebot_agent.memory import Memory, MessageManager
from erniebot_agent.messages import AIMessage, Message


Expand All @@ -22,8 +22,8 @@ class LimitTokensMemory(Memory):
If tokens >= max_token_limit, pop message from memory.
"""

def __init__(self, max_token_limit=None, message_manager=None):
    """Create a token-limited memory.

    Args:
        max_token_limit: Token limit for the memory. Defaults to None.
        message_manager: Manager that stores the messages. Defaults to None,
            in which case a new ``MessageManager`` is created for this
            instance.
    """
    # Build the default manager per instance. A ``MessageManager()`` default
    # in the signature is created once and shared by every instance that
    # omits the argument, so they would all see the same messages.
    super().__init__(MessageManager() if message_manager is None else message_manager)
    self.max_token_limit = max_token_limit
    self.mem_token_count = 0

Expand Down
6 changes: 3 additions & 3 deletions erniebot-agent/erniebot_agent/memory/sliding_window_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from erniebot_agent.memory import Memory
from erniebot_agent.memory import Memory, MessageManager
from erniebot_agent.messages import Message


class SlidingWindowMemory(Memory):
"""This class controls max number of messages."""

def __init__(self, max_num_message: int):
super().__init__()
def __init__(self, max_num_message: int, message_manager=MessageManager()):
super().__init__(message_manager)
self.max_num_message = max_num_message

assert (isinstance(max_num_message, int)) and (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ def setUp(self):
async def test_limit_token_memory(self):
messages = HumanMessage(content="What is the purpose of model regularization?")

memory = LimitTokensMemory(4000)
memory = LimitTokensMemory(4)
memory.add_message(messages)
message = await self.llm.async_chat([messages])
memory.add_message(message)
memory.add_message(HumanMessage("OK, what else?"))
message = await self.llm.async_chat(memory.get_messages())
memory.add_message(message)
self.assertTrue(message is not None)
self.assertTrue(memory.mem_token_count <= 4)

@pytest.mark.asyncio
async def test_limit_token_memory_truncate_tokens(self, k=3): # truncate through returned message
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import asyncio
import unittest

import pytest
from erniebot_agent.memory import PersistentMessageManager, WholeMemory
from erniebot_agent.messages import HumanMessage

from tests.unit_tests.testing_utils import MockErnieBot


class TestSlidingWindowMemory(unittest.TestCase):
    """Check that ``WholeMemory`` grows correctly on top of ``PersistentMessageManager``.

    ``pytest.mark.parametrize`` is not supported on ``unittest.TestCase``
    methods, so the round counts are iterated with ``subTest`` instead.
    """

    # Numbers of human/AI exchanges exercised by each test.
    ROUNDS = (1, 2, 4, 5, 10)

    def setUp(self):
        self.llm = MockErnieBot(None, None, None)

    async def _converse(self, memory, k):
        """Add ``k`` human messages to ``memory``, each answered by the mock LLM."""
        for _ in range(k):
            memory.add_message(HumanMessage(content="What is the purpose of model regularization?"))
            reply = await self.llm.async_chat(memory.get_messages())
            memory.add_message(reply)

    def _assert_two_messages_per_round(self, label, k, **manager_kwargs):
        """Run the conversation and check that each round adds two messages.

        Args:
            label: Prefix printed before the stored conversation.
            k: A single round count to run, or None to run every entry of ``ROUNDS``.
            **manager_kwargs: Arguments forwarded to ``PersistentMessageManager``.
        """
        round_counts = self.ROUNDS if k is None else (k,)
        for count in round_counts:
            with self.subTest(k=count):
                memory = WholeMemory(message_manager=PersistentMessageManager(**manager_kwargs))
                # The message store is module-level state that outlives each
                # sub-test, so measure growth from what is already stored
                # rather than asserting an absolute length.
                baseline = len(memory.get_messages())

                asyncio.run(self._converse(memory, count))
                print(label, memory.msg_manager.client.memory.messages)

                self.assertEqual(len(memory.get_messages()), baseline + 2 * count)

    def test_sliding_window_memory(self, k=None):
        """Without a session id, the manager selects (or creates) the session."""
        self._assert_two_messages_per_round(
            "!!! test_sliding_window_memory_wo_sessionid, conversation output",
            k,
            AK="AK-123",
            url="not used",
            session_id=None,
        )

    def test_sliding_window_memory_with_sessionid(self, k=None):
        """With an explicit session id, new messages extend that session's history."""
        self._assert_two_messages_per_round(
            "!!! test_sliding_window_memory_with_sessionid, conversation output",
            k,
            AK="AK-124",
            url="not used",
            session_id="session-124",
        )


if __name__ == "__main__":
unittest.main()