From fbde085b065dc1ac676fef7cc2006330a5c3e9e5 Mon Sep 17 00:00:00 2001 From: Edwin Yu Date: Wed, 11 Feb 2026 13:44:48 -0800 Subject: [PATCH] Reduce flakiness Signed-off-by: Edwin Yu --- .../test_short_term_memory.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/memmachine/episodic_memory/short_term_memory/test_short_term_memory.py b/tests/memmachine/episodic_memory/short_term_memory/test_short_term_memory.py index 36862f44b..983ffe1b8 100644 --- a/tests/memmachine/episodic_memory/short_term_memory/test_short_term_memory.py +++ b/tests/memmachine/episodic_memory/short_term_memory/test_short_term_memory.py @@ -1,7 +1,6 @@ import asyncio import re import string -import time import uuid from datetime import UTC, datetime from typing import Any, TypeVar, cast @@ -119,6 +118,9 @@ async def update_session_episodic_config( class MockLanguageModel(LanguageModel): """Mock implementation of LanguageModel for testing.""" + def __init__(self): + self.call_count = 0 + @staticmethod def parse_summary(text: str) -> str: m = re.search(r"summary:(\w+)", text) @@ -139,6 +141,7 @@ async def generate_response( raise ValueError("User prompt exceeds context window") if "model error" in prompt: raise RuntimeError("Simulated model error") + self.call_count += 1 await asyncio.sleep(0.1) user_input = self.parse_summary(prompt) return f"summary:{user_input}", "" @@ -270,20 +273,18 @@ async def test_create_delete_episodes(self, memory): assert episodes == [ep1, ep2, ep3] assert len(episodes) == 3 - async def test_summary_behavior(self, memory): + async def test_summary_behavior(self, memory, mock_model): chars = string.digits msgs = [char * 5 for char in chars] - start = time.time() summaries = [] for msg in msgs: ep = create_test_episode(content=msg) await memory.add_episodes([ep]) summaries.append(await memory.get_summary()) - duration = time.time() - start sorted_summaries = [s for s in summaries if s] expected = ["summary:01234567"] assert sorted_summaries == expected - assert 0.1 <= duration < 0.2 + assert mock_model.call_count == 1 async def test_keep_summary_if_model_error(self, memory): episodes = [create_test_episode(content="a" * 100)] @@ -293,17 +294,15 @@ async def test_keep_summary_if_model_error(self, memory): await memory.add_episodes(episodes) assert await memory.get_summary() == "summary:a" - async def test_get_will_wait_for_summary(self, memory): + async def test_get_will_wait_for_summary(self, memory, mock_model): memory._max_message_len = 20 chars = string.digits msgs = [char * 5 for char in chars] - start = time.time() summaries = set() for msg in msgs: ep = create_test_episode(content=msg) await memory.add_episodes([ep]) summaries.add(await memory.get_summary()) - duration = time.time() - start sorted_summary = sorted([s for s in summaries if s]) assert sorted_summary == [ "summary:01234", @@ -311,21 +310,19 @@ async def test_get_will_wait_for_summary(self, memory): "summary:012345678", "summary:0123456789", ] - assert 0.4 <= duration < 0.5 + assert mock_model.call_count == 4 @pytest.mark.asyncio - async def test_summary_exceed_context_window(self, memory): + async def test_summary_exceed_context_window(self, memory, mock_model): chars = string.digits msgs = [char * 2000 for char in chars] - start = time.time() for msg in msgs: ep = create_test_episode(content=msg) await memory.add_episodes([ep]) summary = await memory.get_summary() - duration = time.time() - start assert summary == "summary:" + string.digits # because of context window limit, summary is split into 4 calls - assert 0.4 <= duration < 0.5 + assert mock_model.call_count == 4 @pytest.mark.asyncio async def test_summary_catch_up(self, mock_model, mock_data_manager):