Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import re
import string
import time
import uuid
from datetime import UTC, datetime
from typing import Any, TypeVar, cast
Expand Down Expand Up @@ -119,6 +118,9 @@ async def update_session_episodic_config(
class MockLanguageModel(LanguageModel):
"""Mock implementation of LanguageModel for testing."""

def __init__(self):
self.call_count = 0

@staticmethod
def parse_summary(text: str) -> str:
m = re.search(r"summary:(\w+)", text)
Expand All @@ -139,6 +141,7 @@ async def generate_response(
raise ValueError("User prompt exceeds context window")
if "model error" in prompt:
raise RuntimeError("Simulated model error")
self.call_count += 1
await asyncio.sleep(0.1)
user_input = self.parse_summary(prompt)
return f"summary:{user_input}", ""
Expand Down Expand Up @@ -270,20 +273,18 @@ async def test_create_delete_episodes(self, memory):
assert episodes == [ep1, ep2, ep3]
assert len(episodes) == 3

async def test_summary_behavior(self, memory):
async def test_summary_behavior(self, memory, mock_model):
chars = string.digits
msgs = [char * 5 for char in chars]
start = time.time()
summaries = []
for msg in msgs:
ep = create_test_episode(content=msg)
await memory.add_episodes([ep])
summaries.append(await memory.get_summary())
duration = time.time() - start
sorted_summaries = [s for s in summaries if s]
expected = ["summary:01234567"]
assert sorted_summaries == expected
assert 0.1 <= duration < 0.2
assert mock_model.call_count == 1

async def test_keep_summary_if_model_error(self, memory):
episodes = [create_test_episode(content="a" * 100)]
Expand All @@ -293,39 +294,35 @@ async def test_keep_summary_if_model_error(self, memory):
await memory.add_episodes(episodes)
assert await memory.get_summary() == "summary:a"

async def test_get_will_wait_for_summary(self, memory):
async def test_get_will_wait_for_summary(self, memory, mock_model):
memory._max_message_len = 20
chars = string.digits
msgs = [char * 5 for char in chars]
start = time.time()
summaries = set()
for msg in msgs:
ep = create_test_episode(content=msg)
await memory.add_episodes([ep])
summaries.add(await memory.get_summary())
duration = time.time() - start
sorted_summary = sorted([s for s in summaries if s])
assert sorted_summary == [
"summary:01234",
"summary:0123456",
"summary:012345678",
"summary:0123456789",
]
assert 0.4 <= duration < 0.5
assert mock_model.call_count == 4

@pytest.mark.asyncio
async def test_summary_exceed_context_window(self, memory):
async def test_summary_exceed_context_window(self, memory, mock_model):
chars = string.digits
msgs = [char * 2000 for char in chars]
start = time.time()
for msg in msgs:
ep = create_test_episode(content=msg)
await memory.add_episodes([ep])
summary = await memory.get_summary()
duration = time.time() - start
assert summary == "summary:" + string.digits
# because of context window limit, summary is split into 4 calls
assert 0.4 <= duration < 0.5
assert mock_model.call_count == 4

@pytest.mark.asyncio
async def test_summary_catch_up(self, mock_model, mock_data_manager):
Expand Down
Loading