Implement performance improvements for memory caching, HTTP client reuse, and regex optimization

Co-authored-by: Kenearos <86194771+Kenearos@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot] 2026-01-27 17:33:17 +00:00
parent 860e4d5027
commit b72cd9db1c
5 changed files with 451 additions and 246 deletions

View file

@ -30,12 +30,27 @@ class PerplexityProvider:
self.base_url = "https://api.perplexity.ai"
self.logger = logger or logging.getLogger(__name__)
# Reusable HTTP client (created lazily)
self._client: Optional[httpx.AsyncClient] = None
# Statistics
self.total_requests = 0
self.total_tokens = 0
self.total_errors = 0
self.last_response_time = 0
async def _get_client(self) -> httpx.AsyncClient:
"""Get or create the HTTP client (lazy initialization)"""
if self._client is None or self._client.is_closed:
self._client = httpx.AsyncClient(timeout=30.0)
return self._client
async def close(self):
"""Close the HTTP client"""
if self._client is not None and not self._client.is_closed:
await self._client.aclose()
self._client = None
async def get_response(self, messages):
"""
Send messages to Perplexity API and get response
@ -66,30 +81,30 @@ class PerplexityProvider:
start_time = time.time()
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{self.base_url}/chat/completions",
json=payload,
headers=headers
)
client = await self._get_client()
response = await client.post(
f"{self.base_url}/chat/completions",
json=payload,
headers=headers
)
self.last_response_time = time.time() - start_time
self.last_response_time = time.time() - start_time
if response.status_code == 200:
data = response.json()
content = data['choices'][0]['message']['content']
tokens_used = data.get('usage', {}).get('total_tokens', 0)
if response.status_code == 200:
data = response.json()
content = data['choices'][0]['message']['content']
tokens_used = data.get('usage', {}).get('total_tokens', 0)
# Update statistics
self.total_requests += 1
self.total_tokens += tokens_used
# Update statistics
self.total_requests += 1
self.total_tokens += tokens_used
return content
else:
self.total_errors += 1
error_msg = f"API Error {response.status_code}: {response.text}"
self.logger.error(error_msg)
return None
return content
else:
self.total_errors += 1
error_msg = f"API Error {response.status_code}: {response.text}"
self.logger.error(error_msg)
return None
except httpx.TimeoutException:
self.total_errors += 1

106
memory.py
View file

@ -27,6 +27,8 @@ class ConversationMemory:
self.max_messages = max_messages
self.retention_hours = retention_hours
self.logger = logger or logging.getLogger(__name__)
# In-memory cache to reduce file I/O for frequently accessed users
self._cache: Dict[str, List[Dict]] = {}
def _get_user_file(self, username):
"""Get the file path for a user's conversation history"""
@ -34,6 +36,59 @@ class ConversationMemory:
safe_username = "".join(c for c in username.lower() if c.isalnum() or c in "._-")
return self.data_dir / f"{safe_username}.json"
def _load_user_history(self, username):
"""
Load user history from cache or file
Args:
username (str): Twitch username
Returns:
list: List of message dicts or empty list
"""
safe_username = username.lower()
# Check cache first
if safe_username in self._cache:
return self._cache[safe_username]
file_path = self._get_user_file(username)
if not file_path.exists():
self._cache[safe_username] = []
return []
try:
with open(file_path, 'r', encoding='utf-8') as f:
history = json.load(f)
self._cache[safe_username] = history
return history
except Exception as e:
self.logger.error(f"Error loading history for {username}: {e}")
self._cache[safe_username] = []
return []
def _save_user_history(self, username, history):
"""
Save user history to file and update cache
Args:
username (str): Twitch username
history (list): List of message dicts
"""
safe_username = username.lower()
file_path = self._get_user_file(username)
# Update cache
self._cache[safe_username] = history
# Save to file
try:
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(history, f, ensure_ascii=False, indent=2)
except Exception as e:
self.logger.error(f"Error saving history for {username}: {e}")
def get_user_history(self, username, limit=5):
"""
Load recent chat history for a user
@ -45,29 +100,22 @@ class ConversationMemory:
Returns:
list: List of message dicts with role, content, timestamp
"""
file_path = self._get_user_file(username)
history = self._load_user_history(username)
if not file_path.exists():
if not history:
return []
try:
with open(file_path, 'r', encoding='utf-8') as f:
history = json.load(f)
except Exception as e:
self.logger.error(f"Error loading history for {username}: {e}")
return []
# Filter by retention time
# Filter by retention time using list comprehension for better performance
cutoff_time = datetime.now() - timedelta(hours=self.retention_hours)
recent = []
for msg in history:
def is_recent(msg):
try:
msg_time = datetime.fromisoformat(msg['timestamp'])
if msg_time > cutoff_time:
recent.append(msg)
return msg_time > cutoff_time
except (KeyError, ValueError):
continue
return False
recent = [msg for msg in history if is_recent(msg)]
# Return only the most recent messages up to limit
return recent[-limit:] if recent else []
@ -81,17 +129,8 @@ class ConversationMemory:
role (str): 'user' or 'assistant'
content (str): Message content
"""
file_path = self._get_user_file(username)
# Load existing history
history = []
if file_path.exists():
try:
with open(file_path, 'r', encoding='utf-8') as f:
history = json.load(f)
except Exception as e:
self.logger.error(f"Error loading history for {username}: {e}")
history = []
# Load existing history (uses cache if available)
history = self._load_user_history(username)
# Add new message
history.append({
@ -104,12 +143,8 @@ class ConversationMemory:
if len(history) > self.max_messages:
history = history[-self.max_messages:]
# Save back to file
try:
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(history, f, ensure_ascii=False, indent=2)
except Exception as e:
self.logger.error(f"Error saving history for {username}: {e}")
# Save back to file and update cache
self._save_user_history(username, history)
def format_for_prompt(self, history):
"""
@ -136,6 +171,13 @@ class ConversationMemory:
Args:
username (str): Twitch username
"""
safe_username = username.lower()
# Clear from cache
if safe_username in self._cache:
del self._cache[safe_username]
# Clear from disk
file_path = self._get_user_file(username)
if file_path.exists():
try:

View file

@ -3,7 +3,7 @@ Tests for PerplexityProvider AI API class
"""
import pytest
import httpx
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, Mock, patch, PropertyMock
from ai_provider import PerplexityProvider
@ -42,28 +42,36 @@ class TestPerplexityProvider:
assert provider.total_errors == 0
assert provider.last_response_time == 0
def test_init_client_starts_as_none(self):
"""Test that HTTP client starts as None (lazy initialization)"""
provider = PerplexityProvider(api_key="test-key")
assert provider._client is None
@pytest.mark.asyncio
async def test_get_response_success(self, sample_messages, mock_perplexity_response):
"""Test successful API response"""
provider = PerplexityProvider(api_key="test-key")
# Mock the HTTP client
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_perplexity_response
# Create a mock client with is_closed property
mock_client = Mock()
mock_client.is_closed = False
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_perplexity_response
mock_post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
mock_client.post = AsyncMock(return_value=mock_response)
# Call the method
result = await provider.get_response(sample_messages)
# Inject the mock client
provider._client = mock_client
# Verify result
assert result == "This is a test response from the AI."
assert provider.total_requests == 1
assert provider.total_tokens == 70 # From mock response
assert provider.total_errors == 0
# Call the method
result = await provider.get_response(sample_messages)
# Verify result
assert result == "This is a test response from the AI."
assert provider.total_requests == 1
assert provider.total_tokens == 70 # From mock response
assert provider.total_errors == 0
@pytest.mark.asyncio
async def test_get_response_constructs_correct_payload(self, sample_messages):
@ -75,130 +83,146 @@ class TestPerplexityProvider:
temperature=0.7
)
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "test"}}],
"usage": {"total_tokens": 10}
}
mock_client = Mock()
mock_client.is_closed = False
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "test"}}],
"usage": {"total_tokens": 10}
}
mock_post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
mock_post = AsyncMock(return_value=mock_response)
mock_client.post = mock_post
await provider.get_response(sample_messages)
provider._client = mock_client
# Verify the call was made with correct parameters
call_args = mock_post.call_args
assert call_args[1]["json"]["model"] == "sonar-pro"
assert call_args[1]["json"]["messages"] == sample_messages
assert call_args[1]["json"]["max_tokens"] == 450
assert call_args[1]["json"]["temperature"] == 0.7
await provider.get_response(sample_messages)
# Verify the call was made with correct parameters
call_args = mock_post.call_args
assert call_args[1]["json"]["model"] == "sonar-pro"
assert call_args[1]["json"]["messages"] == sample_messages
assert call_args[1]["json"]["max_tokens"] == 450
assert call_args[1]["json"]["temperature"] == 0.7
@pytest.mark.asyncio
async def test_get_response_includes_auth_header(self, sample_messages):
"""Test that Authorization header is included"""
provider = PerplexityProvider(api_key="test-secret-key")
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "test"}}],
"usage": {"total_tokens": 10}
}
mock_client = Mock()
mock_client.is_closed = False
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "test"}}],
"usage": {"total_tokens": 10}
}
mock_post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
mock_post = AsyncMock(return_value=mock_response)
mock_client.post = mock_post
await provider.get_response(sample_messages)
provider._client = mock_client
call_args = mock_post.call_args
headers = call_args[1]["headers"]
assert headers["Authorization"] == "Bearer test-secret-key"
assert headers["Content-Type"] == "application/json"
await provider.get_response(sample_messages)
call_args = mock_post.call_args
headers = call_args[1]["headers"]
assert headers["Authorization"] == "Bearer test-secret-key"
assert headers["Content-Type"] == "application/json"
@pytest.mark.asyncio
async def test_get_response_handles_401_error(self, sample_messages):
"""Test handling of 401 Unauthorized error"""
provider = PerplexityProvider(api_key="invalid-key")
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 401
mock_response.text = "Unauthorized"
mock_client = Mock()
mock_client.is_closed = False
mock_response = Mock()
mock_response.status_code = 401
mock_response.text = "Unauthorized"
mock_post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
mock_client.post = AsyncMock(return_value=mock_response)
result = await provider.get_response(sample_messages)
provider._client = mock_client
assert result is None
assert provider.total_errors == 1
assert provider.total_requests == 0 # Failed requests don't count
result = await provider.get_response(sample_messages)
assert result is None
assert provider.total_errors == 1
assert provider.total_requests == 0 # Failed requests don't count
@pytest.mark.asyncio
async def test_get_response_handles_500_error(self, sample_messages):
"""Test handling of 500 server error"""
provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 500
mock_response.text = "Internal Server Error"
mock_client = Mock()
mock_client.is_closed = False
mock_response = Mock()
mock_response.status_code = 500
mock_response.text = "Internal Server Error"
mock_post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
mock_client.post = AsyncMock(return_value=mock_response)
result = await provider.get_response(sample_messages)
provider._client = mock_client
assert result is None
assert provider.total_errors == 1
result = await provider.get_response(sample_messages)
assert result is None
assert provider.total_errors == 1
@pytest.mark.asyncio
async def test_get_response_handles_timeout(self, sample_messages):
"""Test handling of request timeout"""
provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client:
mock_post = AsyncMock(side_effect=httpx.TimeoutException("Timeout"))
mock_client.return_value.__aenter__.return_value.post = mock_post
mock_client = Mock()
mock_client.is_closed = False
mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("Timeout"))
result = await provider.get_response(sample_messages)
provider._client = mock_client
assert result is None
assert provider.total_errors == 1
result = await provider.get_response(sample_messages)
assert result is None
assert provider.total_errors == 1
@pytest.mark.asyncio
async def test_get_response_handles_network_error(self, sample_messages):
"""Test handling of network errors"""
provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client:
mock_post = AsyncMock(side_effect=httpx.NetworkError("Network error"))
mock_client.return_value.__aenter__.return_value.post = mock_post
mock_client = Mock()
mock_client.is_closed = False
mock_client.post = AsyncMock(side_effect=httpx.NetworkError("Network error"))
result = await provider.get_response(sample_messages)
provider._client = mock_client
assert result is None
assert provider.total_errors == 1
result = await provider.get_response(sample_messages)
assert result is None
assert provider.total_errors == 1
@pytest.mark.asyncio
async def test_get_response_tracks_response_time(self, sample_messages, mock_perplexity_response):
"""Test that response time is tracked"""
provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_perplexity_response
mock_client = Mock()
mock_client.is_closed = False
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_perplexity_response
mock_post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
mock_client.post = AsyncMock(return_value=mock_response)
await provider.get_response(sample_messages)
provider._client = mock_client
assert provider.last_response_time > 0
await provider.get_response(sample_messages)
assert provider.last_response_time > 0
@pytest.mark.asyncio
async def test_validate_api_key_success(self):
@ -282,56 +306,61 @@ class TestPerplexityProvider:
"""Test statistics after successful requests"""
provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_perplexity_response
mock_client = Mock()
mock_client.is_closed = False
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_perplexity_response
mock_post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
mock_client.post = AsyncMock(return_value=mock_response)
# Make 2 successful requests
await provider.get_response(sample_messages)
await provider.get_response(sample_messages)
provider._client = mock_client
stats = provider.get_statistics()
# Make 2 successful requests
await provider.get_response(sample_messages)
await provider.get_response(sample_messages)
assert stats["total_requests"] == 2
assert stats["total_tokens"] == 140 # 70 * 2
assert stats["total_errors"] == 0
assert stats["success_rate"] == 100
assert stats["estimated_cost"] > 0
stats = provider.get_statistics()
assert stats["total_requests"] == 2
assert stats["total_tokens"] == 140 # 70 * 2
assert stats["total_errors"] == 0
assert stats["success_rate"] == 100
assert stats["estimated_cost"] > 0
@pytest.mark.asyncio
async def test_get_statistics_with_errors(self, sample_messages):
"""Test statistics calculation with errors"""
provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client:
# First request succeeds
mock_response_success = Mock()
mock_response_success.status_code = 200
mock_response_success.json.return_value = {
"choices": [{"message": {"content": "test"}}],
"usage": {"total_tokens": 50}
}
mock_client = Mock()
mock_client.is_closed = False
# Second request fails
mock_response_fail = Mock()
mock_response_fail.status_code = 500
mock_response_fail.text = "Error"
# First request succeeds
mock_response_success = Mock()
mock_response_success.status_code = 200
mock_response_success.json.return_value = {
"choices": [{"message": {"content": "test"}}],
"usage": {"total_tokens": 50}
}
mock_post = AsyncMock(side_effect=[mock_response_success, mock_response_fail])
mock_client.return_value.__aenter__.return_value.post = mock_post
# Second request fails
mock_response_fail = Mock()
mock_response_fail.status_code = 500
mock_response_fail.text = "Error"
await provider.get_response(sample_messages)
await provider.get_response(sample_messages)
mock_client.post = AsyncMock(side_effect=[mock_response_success, mock_response_fail])
stats = provider.get_statistics()
provider._client = mock_client
assert stats["total_requests"] == 1
assert stats["total_errors"] == 1
assert stats["success_rate"] == 0 # 0 out of 1 completed request succeeded
await provider.get_response(sample_messages)
await provider.get_response(sample_messages)
stats = provider.get_statistics()
assert stats["total_requests"] == 1
assert stats["total_errors"] == 1
assert stats["success_rate"] == 0 # 0 out of 1 completed request succeeded
def test_reset_statistics(self):
"""Test resetting statistics"""
@ -355,43 +384,48 @@ class TestPerplexityProvider:
"""Test handling response without usage data"""
provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "test response"}}]
# No usage field
}
mock_client = Mock()
mock_client.is_closed = False
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "test response"}}]
# No usage field
}
mock_post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
mock_client.post = AsyncMock(return_value=mock_response)
result = await provider.get_response(sample_messages)
provider._client = mock_client
assert result == "test response"
assert provider.total_tokens == 0 # Should default to 0
result = await provider.get_response(sample_messages)
assert result == "test response"
assert provider.total_tokens == 0 # Should default to 0
@pytest.mark.asyncio
async def test_get_response_uses_correct_endpoint(self, sample_messages):
"""Test that correct API endpoint is used"""
provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "test"}}],
"usage": {"total_tokens": 10}
}
mock_client = Mock()
mock_client.is_closed = False
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"choices": [{"message": {"content": "test"}}],
"usage": {"total_tokens": 10}
}
mock_post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
mock_post = AsyncMock(return_value=mock_response)
mock_client.post = mock_post
await provider.get_response(sample_messages)
provider._client = mock_client
call_args = mock_post.call_args
# First positional argument should be the URL
assert call_args[0][0] == "https://api.perplexity.ai/chat/completions"
await provider.get_response(sample_messages)
call_args = mock_post.call_args
# First positional argument should be the URL
assert call_args[0][0] == "https://api.perplexity.ai/chat/completions"
@pytest.mark.asyncio
async def test_get_response_timeout_configuration(self, sample_messages):
@ -399,6 +433,8 @@ class TestPerplexityProvider:
provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client_class:
mock_client_instance = Mock()
mock_client_instance.is_closed = False
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
@ -406,11 +442,96 @@ class TestPerplexityProvider:
"usage": {"total_tokens": 10}
}
mock_post = AsyncMock(return_value=mock_response)
mock_client_class.return_value.__aenter__.return_value.post = mock_post
mock_client_instance.post = AsyncMock(return_value=mock_response)
mock_client_class.return_value = mock_client_instance
await provider.get_response(sample_messages)
# Check that AsyncClient was instantiated with timeout
call_args = mock_client_class.call_args
assert call_args[1]["timeout"] == 30.0
@pytest.mark.asyncio
async def test_get_client_creates_client_when_none(self):
"""Test that _get_client creates a new client when none exists"""
provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client_class:
mock_client_instance = Mock()
mock_client_instance.is_closed = False
mock_client_class.return_value = mock_client_instance
client = await provider._get_client()
assert client is mock_client_instance
mock_client_class.assert_called_once_with(timeout=30.0)
@pytest.mark.asyncio
async def test_get_client_reuses_existing_client(self):
"""Test that _get_client reuses an existing open client"""
provider = PerplexityProvider(api_key="test-key")
mock_client = Mock()
mock_client.is_closed = False
provider._client = mock_client
with patch('httpx.AsyncClient') as mock_client_class:
client = await provider._get_client()
assert client is mock_client
mock_client_class.assert_not_called()
@pytest.mark.asyncio
async def test_get_client_recreates_closed_client(self):
"""Test that _get_client creates a new client when existing one is closed"""
provider = PerplexityProvider(api_key="test-key")
old_client = Mock()
old_client.is_closed = True
provider._client = old_client
with patch('httpx.AsyncClient') as mock_client_class:
new_client = Mock()
new_client.is_closed = False
mock_client_class.return_value = new_client
client = await provider._get_client()
assert client is new_client
mock_client_class.assert_called_once()
@pytest.mark.asyncio
async def test_close_closes_client(self):
"""Test that close() properly closes the HTTP client"""
provider = PerplexityProvider(api_key="test-key")
mock_client = Mock()
mock_client.is_closed = False
mock_client.aclose = AsyncMock()
provider._client = mock_client
await provider.close()
mock_client.aclose.assert_called_once()
assert provider._client is None
@pytest.mark.asyncio
async def test_close_handles_none_client(self):
"""Test that close() handles case when client is None"""
provider = PerplexityProvider(api_key="test-key")
provider._client = None
# Should not raise any errors
await provider.close()
@pytest.mark.asyncio
async def test_close_handles_already_closed_client(self):
"""Test that close() handles already closed client"""
provider = PerplexityProvider(api_key="test-key")
mock_client = Mock()
mock_client.is_closed = True
provider._client = mock_client
# Should not raise any errors
await provider.close()

View file

@ -22,48 +22,51 @@ class TestFullWorkflow:
detector = MentionDetector(config.bot_name)
ai = PerplexityProvider(api_key=config.perplexity_key, model=config.model)
# Mock API response
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_perplexity_response
# Mock API client for lazy initialization
mock_client = Mock()
mock_client.is_closed = False
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_perplexity_response
mock_post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
mock_client.post = AsyncMock(return_value=mock_response)
# Simulate user message
username = "testuser"
user_message = "@TestBot what's the weather?"
# Inject the mock client
ai._client = mock_client
# 1. Detect mention
assert detector.is_mentioned(user_message)
# Simulate user message
username = "testuser"
user_message = "@TestBot what's the weather?"
# 2. Extract content
content = detector.extract_content(user_message)
assert content == "what's the weather?"
# 1. Detect mention
assert detector.is_mentioned(user_message)
# 3. Load history
history = memory.get_user_history(username, limit=5)
assert len(history) == 0 # First message
# 2. Extract content
content = detector.extract_content(user_message)
assert content == "what's the weather?"
# 4. Build messages for API
messages = [{"role": "system", "content": config.get_system_prompt()}]
messages.extend(memory.format_for_prompt(history))
messages.append({"role": "user", "content": content})
# 3. Load history
history = memory.get_user_history(username, limit=5)
assert len(history) == 0 # First message
# 5. Get AI response
response = await ai.get_response(messages)
assert response == "This is a test response from the AI."
# 4. Build messages for API
messages = [{"role": "system", "content": config.get_system_prompt()}]
messages.extend(memory.format_for_prompt(history))
messages.append({"role": "user", "content": content})
# 6. Save to memory
memory.add_message(username, "user", content)
memory.add_message(username, "assistant", response)
# 5. Get AI response
response = await ai.get_response(messages)
assert response == "This is a test response from the AI."
# 7. Verify history was saved
saved_history = memory.get_user_history(username)
assert len(saved_history) == 2
assert saved_history[0]["content"] == content
assert saved_history[1]["content"] == response
# 6. Save to memory
memory.add_message(username, "user", content)
memory.add_message(username, "assistant", response)
# 7. Verify history was saved
saved_history = memory.get_user_history(username)
assert len(saved_history) == 2
assert saved_history[0]["content"] == content
assert saved_history[1]["content"] == response
@pytest.mark.asyncio
async def test_conversation_context_preserved(self, temp_dir, mock_env_file):

View file

@ -20,8 +20,22 @@ class MentionDetector:
# Create patterns for various mention formats
# Include bot name and all nicknames
all_names = [bot_name] + self.nicknames
self.patterns = []
# Build a single combined pattern for efficient matching (uses alternation)
name_alternatives = "|".join(re.escape(name) for name in all_names)
# Combined pattern that matches any mention format
# This is more efficient than checking multiple patterns separately
combined_mention_pattern = (
rf"(?:@(?:{name_alternatives})\b)|" # @name
rf"(?:\b(?:{name_alternatives})[:!?.,])|" # name: name! etc.
rf"(?:^(?:{name_alternatives})\b)|" # name at start
rf"(?:\b(?:{name_alternatives})\b)" # name anywhere
)
self._mention_pattern = re.compile(combined_mention_pattern, re.IGNORECASE)
# Keep individual patterns list for backward compatibility (tests may use it)
self.patterns = []
for name in all_names:
self.patterns.extend([
rf"@{name}\b", # @name (with word boundary)
@ -29,12 +43,21 @@ class MentionDetector:
rf"^{name}\b", # name at start of message
rf"\b{name}\b", # name anywhere as whole word
])
# Case-insensitive compilation
self.compiled_patterns = [
re.compile(pattern, re.IGNORECASE) for pattern in self.patterns
]
# Pre-compile extraction patterns for each name (more efficient than recompiling each time)
self._extraction_patterns = {}
for name in all_names:
escaped_name = re.escape(name)
self._extraction_patterns[name.lower()] = {
'at_start': re.compile(rf"^@{escaped_name}\b[,:]?\s*", re.IGNORECASE),
'name_start': re.compile(rf"^{escaped_name}\b[,:]?\s*", re.IGNORECASE),
'name_end': re.compile(rf"\s*\b{escaped_name}[,!?.]?\s*$", re.IGNORECASE),
'name_middle': re.compile(rf"\s*\b{escaped_name}[,:!?]\s*", re.IGNORECASE),
}
# Patterns for ambiguous greetings (might be directed at bot)
self.greeting_patterns = [
r"^(hi|hey|hallo|hello|servus|moin)(\s|$|\W)",
@ -78,10 +101,8 @@ class MentionDetector:
if not message:
return False
for pattern in self.compiled_patterns:
if pattern.search(message):
return True
return False
# Use the optimized single combined pattern for faster matching
return bool(self._mention_pattern.search(message))
def is_ambiguous_greeting(self, message):
"""
@ -122,18 +143,21 @@ class MentionDetector:
content = message
all_names = [self.bot_name] + self.nicknames
# Use pre-compiled patterns for better performance
for name in all_names:
# Remove @mention at start
content = re.sub(rf"^@{name}\b[,:]?\s*", "", content, flags=re.IGNORECASE)
patterns = self._extraction_patterns.get(name.lower())
if patterns:
# Remove @mention at start
content = patterns['at_start'].sub("", content)
# Remove name at start with optional punctuation
content = re.sub(rf"^{name}\b[,:]?\s*", "", content, flags=re.IGNORECASE)
# Remove name at start with optional punctuation
content = patterns['name_start'].sub("", content)
# Remove name at end with optional punctuation
content = re.sub(rf"\s*\b{name}[,!?.]?\s*$", "", content, flags=re.IGNORECASE)
# Remove name at end with optional punctuation
content = patterns['name_end'].sub("", content)
# Remove name in middle with punctuation
content = re.sub(rf"\s*\b{name}[,:!?]\s*", " ", content, flags=re.IGNORECASE)
# Remove name in middle with punctuation
content = patterns['name_middle'].sub(" ", content)
return content.strip()