Implement performance improvements for memory caching, HTTP client reuse, and regex optimization

Co-authored-by: Kenearos <86194771+Kenearos@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot] 2026-01-27 17:33:17 +00:00
parent 860e4d5027
commit b72cd9db1c
5 changed files with 451 additions and 246 deletions

View file

@ -30,12 +30,27 @@ class PerplexityProvider:
self.base_url = "https://api.perplexity.ai" self.base_url = "https://api.perplexity.ai"
self.logger = logger or logging.getLogger(__name__) self.logger = logger or logging.getLogger(__name__)
# Reusable HTTP client (created lazily)
self._client: Optional[httpx.AsyncClient] = None
# Statistics # Statistics
self.total_requests = 0 self.total_requests = 0
self.total_tokens = 0 self.total_tokens = 0
self.total_errors = 0 self.total_errors = 0
self.last_response_time = 0 self.last_response_time = 0
async def _get_client(self) -> httpx.AsyncClient:
"""Get or create the HTTP client (lazy initialization)"""
if self._client is None or self._client.is_closed:
self._client = httpx.AsyncClient(timeout=30.0)
return self._client
async def close(self):
"""Close the HTTP client"""
if self._client is not None and not self._client.is_closed:
await self._client.aclose()
self._client = None
async def get_response(self, messages): async def get_response(self, messages):
""" """
Send messages to Perplexity API and get response Send messages to Perplexity API and get response
@ -66,30 +81,30 @@ class PerplexityProvider:
start_time = time.time() start_time = time.time()
try: try:
async with httpx.AsyncClient(timeout=30.0) as client: client = await self._get_client()
response = await client.post( response = await client.post(
f"{self.base_url}/chat/completions", f"{self.base_url}/chat/completions",
json=payload, json=payload,
headers=headers headers=headers
) )
self.last_response_time = time.time() - start_time self.last_response_time = time.time() - start_time
if response.status_code == 200: if response.status_code == 200:
data = response.json() data = response.json()
content = data['choices'][0]['message']['content'] content = data['choices'][0]['message']['content']
tokens_used = data.get('usage', {}).get('total_tokens', 0) tokens_used = data.get('usage', {}).get('total_tokens', 0)
# Update statistics # Update statistics
self.total_requests += 1 self.total_requests += 1
self.total_tokens += tokens_used self.total_tokens += tokens_used
return content return content
else: else:
self.total_errors += 1 self.total_errors += 1
error_msg = f"API Error {response.status_code}: {response.text}" error_msg = f"API Error {response.status_code}: {response.text}"
self.logger.error(error_msg) self.logger.error(error_msg)
return None return None
except httpx.TimeoutException: except httpx.TimeoutException:
self.total_errors += 1 self.total_errors += 1

106
memory.py
View file

@ -27,6 +27,8 @@ class ConversationMemory:
self.max_messages = max_messages self.max_messages = max_messages
self.retention_hours = retention_hours self.retention_hours = retention_hours
self.logger = logger or logging.getLogger(__name__) self.logger = logger or logging.getLogger(__name__)
# In-memory cache to reduce file I/O for frequently accessed users
self._cache: Dict[str, List[Dict]] = {}
def _get_user_file(self, username): def _get_user_file(self, username):
"""Get the file path for a user's conversation history""" """Get the file path for a user's conversation history"""
@ -34,6 +36,59 @@ class ConversationMemory:
safe_username = "".join(c for c in username.lower() if c.isalnum() or c in "._-") safe_username = "".join(c for c in username.lower() if c.isalnum() or c in "._-")
return self.data_dir / f"{safe_username}.json" return self.data_dir / f"{safe_username}.json"
def _load_user_history(self, username):
"""
Load user history from cache or file
Args:
username (str): Twitch username
Returns:
list: List of message dicts or empty list
"""
safe_username = username.lower()
# Check cache first
if safe_username in self._cache:
return self._cache[safe_username]
file_path = self._get_user_file(username)
if not file_path.exists():
self._cache[safe_username] = []
return []
try:
with open(file_path, 'r', encoding='utf-8') as f:
history = json.load(f)
self._cache[safe_username] = history
return history
except Exception as e:
self.logger.error(f"Error loading history for {username}: {e}")
self._cache[safe_username] = []
return []
def _save_user_history(self, username, history):
"""
Save user history to file and update cache
Args:
username (str): Twitch username
history (list): List of message dicts
"""
safe_username = username.lower()
file_path = self._get_user_file(username)
# Update cache
self._cache[safe_username] = history
# Save to file
try:
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(history, f, ensure_ascii=False, indent=2)
except Exception as e:
self.logger.error(f"Error saving history for {username}: {e}")
def get_user_history(self, username, limit=5): def get_user_history(self, username, limit=5):
""" """
Load recent chat history for a user Load recent chat history for a user
@ -45,29 +100,22 @@ class ConversationMemory:
Returns: Returns:
list: List of message dicts with role, content, timestamp list: List of message dicts with role, content, timestamp
""" """
file_path = self._get_user_file(username) history = self._load_user_history(username)
if not file_path.exists(): if not history:
return [] return []
try: # Filter by retention time using list comprehension for better performance
with open(file_path, 'r', encoding='utf-8') as f:
history = json.load(f)
except Exception as e:
self.logger.error(f"Error loading history for {username}: {e}")
return []
# Filter by retention time
cutoff_time = datetime.now() - timedelta(hours=self.retention_hours) cutoff_time = datetime.now() - timedelta(hours=self.retention_hours)
recent = []
for msg in history: def is_recent(msg):
try: try:
msg_time = datetime.fromisoformat(msg['timestamp']) msg_time = datetime.fromisoformat(msg['timestamp'])
if msg_time > cutoff_time: return msg_time > cutoff_time
recent.append(msg)
except (KeyError, ValueError): except (KeyError, ValueError):
continue return False
recent = [msg for msg in history if is_recent(msg)]
# Return only the most recent messages up to limit # Return only the most recent messages up to limit
return recent[-limit:] if recent else [] return recent[-limit:] if recent else []
@ -81,17 +129,8 @@ class ConversationMemory:
role (str): 'user' or 'assistant' role (str): 'user' or 'assistant'
content (str): Message content content (str): Message content
""" """
file_path = self._get_user_file(username) # Load existing history (uses cache if available)
history = self._load_user_history(username)
# Load existing history
history = []
if file_path.exists():
try:
with open(file_path, 'r', encoding='utf-8') as f:
history = json.load(f)
except Exception as e:
self.logger.error(f"Error loading history for {username}: {e}")
history = []
# Add new message # Add new message
history.append({ history.append({
@ -104,12 +143,8 @@ class ConversationMemory:
if len(history) > self.max_messages: if len(history) > self.max_messages:
history = history[-self.max_messages:] history = history[-self.max_messages:]
# Save back to file # Save back to file and update cache
try: self._save_user_history(username, history)
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(history, f, ensure_ascii=False, indent=2)
except Exception as e:
self.logger.error(f"Error saving history for {username}: {e}")
def format_for_prompt(self, history): def format_for_prompt(self, history):
""" """
@ -136,6 +171,13 @@ class ConversationMemory:
Args: Args:
username (str): Twitch username username (str): Twitch username
""" """
safe_username = username.lower()
# Clear from cache
if safe_username in self._cache:
del self._cache[safe_username]
# Clear from disk
file_path = self._get_user_file(username) file_path = self._get_user_file(username)
if file_path.exists(): if file_path.exists():
try: try:

View file

@ -3,7 +3,7 @@ Tests for PerplexityProvider AI API class
""" """
import pytest import pytest
import httpx import httpx
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch, PropertyMock
from ai_provider import PerplexityProvider from ai_provider import PerplexityProvider
@ -42,28 +42,36 @@ class TestPerplexityProvider:
assert provider.total_errors == 0 assert provider.total_errors == 0
assert provider.last_response_time == 0 assert provider.last_response_time == 0
def test_init_client_starts_as_none(self):
"""Test that HTTP client starts as None (lazy initialization)"""
provider = PerplexityProvider(api_key="test-key")
assert provider._client is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_response_success(self, sample_messages, mock_perplexity_response): async def test_get_response_success(self, sample_messages, mock_perplexity_response):
"""Test successful API response""" """Test successful API response"""
provider = PerplexityProvider(api_key="test-key") provider = PerplexityProvider(api_key="test-key")
# Mock the HTTP client # Create a mock client with is_closed property
with patch('httpx.AsyncClient') as mock_client: mock_client = Mock()
mock_response = Mock() mock_client.is_closed = False
mock_response.status_code = 200 mock_response = Mock()
mock_response.json.return_value = mock_perplexity_response mock_response.status_code = 200
mock_response.json.return_value = mock_perplexity_response
mock_post = AsyncMock(return_value=mock_response) mock_client.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
# Call the method # Inject the mock client
result = await provider.get_response(sample_messages) provider._client = mock_client
# Verify result # Call the method
assert result == "This is a test response from the AI." result = await provider.get_response(sample_messages)
assert provider.total_requests == 1
assert provider.total_tokens == 70 # From mock response # Verify result
assert provider.total_errors == 0 assert result == "This is a test response from the AI."
assert provider.total_requests == 1
assert provider.total_tokens == 70 # From mock response
assert provider.total_errors == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_response_constructs_correct_payload(self, sample_messages): async def test_get_response_constructs_correct_payload(self, sample_messages):
@ -75,130 +83,146 @@ class TestPerplexityProvider:
temperature=0.7 temperature=0.7
) )
with patch('httpx.AsyncClient') as mock_client: mock_client = Mock()
mock_response = Mock() mock_client.is_closed = False
mock_response.status_code = 200 mock_response = Mock()
mock_response.json.return_value = { mock_response.status_code = 200
"choices": [{"message": {"content": "test"}}], mock_response.json.return_value = {
"usage": {"total_tokens": 10} "choices": [{"message": {"content": "test"}}],
} "usage": {"total_tokens": 10}
}
mock_post = AsyncMock(return_value=mock_response) mock_post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post mock_client.post = mock_post
await provider.get_response(sample_messages) provider._client = mock_client
# Verify the call was made with correct parameters await provider.get_response(sample_messages)
call_args = mock_post.call_args
assert call_args[1]["json"]["model"] == "sonar-pro" # Verify the call was made with correct parameters
assert call_args[1]["json"]["messages"] == sample_messages call_args = mock_post.call_args
assert call_args[1]["json"]["max_tokens"] == 450 assert call_args[1]["json"]["model"] == "sonar-pro"
assert call_args[1]["json"]["temperature"] == 0.7 assert call_args[1]["json"]["messages"] == sample_messages
assert call_args[1]["json"]["max_tokens"] == 450
assert call_args[1]["json"]["temperature"] == 0.7
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_response_includes_auth_header(self, sample_messages): async def test_get_response_includes_auth_header(self, sample_messages):
"""Test that Authorization header is included""" """Test that Authorization header is included"""
provider = PerplexityProvider(api_key="test-secret-key") provider = PerplexityProvider(api_key="test-secret-key")
with patch('httpx.AsyncClient') as mock_client: mock_client = Mock()
mock_response = Mock() mock_client.is_closed = False
mock_response.status_code = 200 mock_response = Mock()
mock_response.json.return_value = { mock_response.status_code = 200
"choices": [{"message": {"content": "test"}}], mock_response.json.return_value = {
"usage": {"total_tokens": 10} "choices": [{"message": {"content": "test"}}],
} "usage": {"total_tokens": 10}
}
mock_post = AsyncMock(return_value=mock_response) mock_post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post mock_client.post = mock_post
await provider.get_response(sample_messages) provider._client = mock_client
call_args = mock_post.call_args await provider.get_response(sample_messages)
headers = call_args[1]["headers"]
assert headers["Authorization"] == "Bearer test-secret-key" call_args = mock_post.call_args
assert headers["Content-Type"] == "application/json" headers = call_args[1]["headers"]
assert headers["Authorization"] == "Bearer test-secret-key"
assert headers["Content-Type"] == "application/json"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_response_handles_401_error(self, sample_messages): async def test_get_response_handles_401_error(self, sample_messages):
"""Test handling of 401 Unauthorized error""" """Test handling of 401 Unauthorized error"""
provider = PerplexityProvider(api_key="invalid-key") provider = PerplexityProvider(api_key="invalid-key")
with patch('httpx.AsyncClient') as mock_client: mock_client = Mock()
mock_response = Mock() mock_client.is_closed = False
mock_response.status_code = 401 mock_response = Mock()
mock_response.text = "Unauthorized" mock_response.status_code = 401
mock_response.text = "Unauthorized"
mock_post = AsyncMock(return_value=mock_response) mock_client.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
result = await provider.get_response(sample_messages) provider._client = mock_client
assert result is None result = await provider.get_response(sample_messages)
assert provider.total_errors == 1
assert provider.total_requests == 0 # Failed requests don't count assert result is None
assert provider.total_errors == 1
assert provider.total_requests == 0 # Failed requests don't count
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_response_handles_500_error(self, sample_messages): async def test_get_response_handles_500_error(self, sample_messages):
"""Test handling of 500 server error""" """Test handling of 500 server error"""
provider = PerplexityProvider(api_key="test-key") provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client: mock_client = Mock()
mock_response = Mock() mock_client.is_closed = False
mock_response.status_code = 500 mock_response = Mock()
mock_response.text = "Internal Server Error" mock_response.status_code = 500
mock_response.text = "Internal Server Error"
mock_post = AsyncMock(return_value=mock_response) mock_client.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
result = await provider.get_response(sample_messages) provider._client = mock_client
assert result is None result = await provider.get_response(sample_messages)
assert provider.total_errors == 1
assert result is None
assert provider.total_errors == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_response_handles_timeout(self, sample_messages): async def test_get_response_handles_timeout(self, sample_messages):
"""Test handling of request timeout""" """Test handling of request timeout"""
provider = PerplexityProvider(api_key="test-key") provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client: mock_client = Mock()
mock_post = AsyncMock(side_effect=httpx.TimeoutException("Timeout")) mock_client.is_closed = False
mock_client.return_value.__aenter__.return_value.post = mock_post mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("Timeout"))
result = await provider.get_response(sample_messages) provider._client = mock_client
assert result is None result = await provider.get_response(sample_messages)
assert provider.total_errors == 1
assert result is None
assert provider.total_errors == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_response_handles_network_error(self, sample_messages): async def test_get_response_handles_network_error(self, sample_messages):
"""Test handling of network errors""" """Test handling of network errors"""
provider = PerplexityProvider(api_key="test-key") provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client: mock_client = Mock()
mock_post = AsyncMock(side_effect=httpx.NetworkError("Network error")) mock_client.is_closed = False
mock_client.return_value.__aenter__.return_value.post = mock_post mock_client.post = AsyncMock(side_effect=httpx.NetworkError("Network error"))
result = await provider.get_response(sample_messages) provider._client = mock_client
assert result is None result = await provider.get_response(sample_messages)
assert provider.total_errors == 1
assert result is None
assert provider.total_errors == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_response_tracks_response_time(self, sample_messages, mock_perplexity_response): async def test_get_response_tracks_response_time(self, sample_messages, mock_perplexity_response):
"""Test that response time is tracked""" """Test that response time is tracked"""
provider = PerplexityProvider(api_key="test-key") provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client: mock_client = Mock()
mock_response = Mock() mock_client.is_closed = False
mock_response.status_code = 200 mock_response = Mock()
mock_response.json.return_value = mock_perplexity_response mock_response.status_code = 200
mock_response.json.return_value = mock_perplexity_response
mock_post = AsyncMock(return_value=mock_response) mock_client.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
await provider.get_response(sample_messages) provider._client = mock_client
assert provider.last_response_time > 0 await provider.get_response(sample_messages)
assert provider.last_response_time > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_validate_api_key_success(self): async def test_validate_api_key_success(self):
@ -282,56 +306,61 @@ class TestPerplexityProvider:
"""Test statistics after successful requests""" """Test statistics after successful requests"""
provider = PerplexityProvider(api_key="test-key") provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client: mock_client = Mock()
mock_response = Mock() mock_client.is_closed = False
mock_response.status_code = 200 mock_response = Mock()
mock_response.json.return_value = mock_perplexity_response mock_response.status_code = 200
mock_response.json.return_value = mock_perplexity_response
mock_post = AsyncMock(return_value=mock_response) mock_client.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
# Make 2 successful requests provider._client = mock_client
await provider.get_response(sample_messages)
await provider.get_response(sample_messages)
stats = provider.get_statistics() # Make 2 successful requests
await provider.get_response(sample_messages)
await provider.get_response(sample_messages)
assert stats["total_requests"] == 2 stats = provider.get_statistics()
assert stats["total_tokens"] == 140 # 70 * 2
assert stats["total_errors"] == 0 assert stats["total_requests"] == 2
assert stats["success_rate"] == 100 assert stats["total_tokens"] == 140 # 70 * 2
assert stats["estimated_cost"] > 0 assert stats["total_errors"] == 0
assert stats["success_rate"] == 100
assert stats["estimated_cost"] > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_statistics_with_errors(self, sample_messages): async def test_get_statistics_with_errors(self, sample_messages):
"""Test statistics calculation with errors""" """Test statistics calculation with errors"""
provider = PerplexityProvider(api_key="test-key") provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client: mock_client = Mock()
# First request succeeds mock_client.is_closed = False
mock_response_success = Mock()
mock_response_success.status_code = 200
mock_response_success.json.return_value = {
"choices": [{"message": {"content": "test"}}],
"usage": {"total_tokens": 50}
}
# Second request fails # First request succeeds
mock_response_fail = Mock() mock_response_success = Mock()
mock_response_fail.status_code = 500 mock_response_success.status_code = 200
mock_response_fail.text = "Error" mock_response_success.json.return_value = {
"choices": [{"message": {"content": "test"}}],
"usage": {"total_tokens": 50}
}
mock_post = AsyncMock(side_effect=[mock_response_success, mock_response_fail]) # Second request fails
mock_client.return_value.__aenter__.return_value.post = mock_post mock_response_fail = Mock()
mock_response_fail.status_code = 500
mock_response_fail.text = "Error"
await provider.get_response(sample_messages) mock_client.post = AsyncMock(side_effect=[mock_response_success, mock_response_fail])
await provider.get_response(sample_messages)
stats = provider.get_statistics() provider._client = mock_client
assert stats["total_requests"] == 1 await provider.get_response(sample_messages)
assert stats["total_errors"] == 1 await provider.get_response(sample_messages)
assert stats["success_rate"] == 0 # 0 out of 1 completed request succeeded
stats = provider.get_statistics()
assert stats["total_requests"] == 1
assert stats["total_errors"] == 1
assert stats["success_rate"] == 0 # 0 out of 1 completed request succeeded
def test_reset_statistics(self): def test_reset_statistics(self):
"""Test resetting statistics""" """Test resetting statistics"""
@ -355,43 +384,48 @@ class TestPerplexityProvider:
"""Test handling response without usage data""" """Test handling response without usage data"""
provider = PerplexityProvider(api_key="test-key") provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client: mock_client = Mock()
mock_response = Mock() mock_client.is_closed = False
mock_response.status_code = 200 mock_response = Mock()
mock_response.json.return_value = { mock_response.status_code = 200
"choices": [{"message": {"content": "test response"}}] mock_response.json.return_value = {
# No usage field "choices": [{"message": {"content": "test response"}}]
} # No usage field
}
mock_post = AsyncMock(return_value=mock_response) mock_client.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
result = await provider.get_response(sample_messages) provider._client = mock_client
assert result == "test response" result = await provider.get_response(sample_messages)
assert provider.total_tokens == 0 # Should default to 0
assert result == "test response"
assert provider.total_tokens == 0 # Should default to 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_response_uses_correct_endpoint(self, sample_messages): async def test_get_response_uses_correct_endpoint(self, sample_messages):
"""Test that correct API endpoint is used""" """Test that correct API endpoint is used"""
provider = PerplexityProvider(api_key="test-key") provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client: mock_client = Mock()
mock_response = Mock() mock_client.is_closed = False
mock_response.status_code = 200 mock_response = Mock()
mock_response.json.return_value = { mock_response.status_code = 200
"choices": [{"message": {"content": "test"}}], mock_response.json.return_value = {
"usage": {"total_tokens": 10} "choices": [{"message": {"content": "test"}}],
} "usage": {"total_tokens": 10}
}
mock_post = AsyncMock(return_value=mock_response) mock_post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post mock_client.post = mock_post
await provider.get_response(sample_messages) provider._client = mock_client
call_args = mock_post.call_args await provider.get_response(sample_messages)
# First positional argument should be the URL
assert call_args[0][0] == "https://api.perplexity.ai/chat/completions" call_args = mock_post.call_args
# First positional argument should be the URL
assert call_args[0][0] == "https://api.perplexity.ai/chat/completions"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_response_timeout_configuration(self, sample_messages): async def test_get_response_timeout_configuration(self, sample_messages):
@ -399,6 +433,8 @@ class TestPerplexityProvider:
provider = PerplexityProvider(api_key="test-key") provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client_class: with patch('httpx.AsyncClient') as mock_client_class:
mock_client_instance = Mock()
mock_client_instance.is_closed = False
mock_response = Mock() mock_response = Mock()
mock_response.status_code = 200 mock_response.status_code = 200
mock_response.json.return_value = { mock_response.json.return_value = {
@ -406,11 +442,96 @@ class TestPerplexityProvider:
"usage": {"total_tokens": 10} "usage": {"total_tokens": 10}
} }
mock_post = AsyncMock(return_value=mock_response) mock_client_instance.post = AsyncMock(return_value=mock_response)
mock_client_class.return_value.__aenter__.return_value.post = mock_post mock_client_class.return_value = mock_client_instance
await provider.get_response(sample_messages) await provider.get_response(sample_messages)
# Check that AsyncClient was instantiated with timeout # Check that AsyncClient was instantiated with timeout
call_args = mock_client_class.call_args call_args = mock_client_class.call_args
assert call_args[1]["timeout"] == 30.0 assert call_args[1]["timeout"] == 30.0
@pytest.mark.asyncio
async def test_get_client_creates_client_when_none(self):
"""Test that _get_client creates a new client when none exists"""
provider = PerplexityProvider(api_key="test-key")
with patch('httpx.AsyncClient') as mock_client_class:
mock_client_instance = Mock()
mock_client_instance.is_closed = False
mock_client_class.return_value = mock_client_instance
client = await provider._get_client()
assert client is mock_client_instance
mock_client_class.assert_called_once_with(timeout=30.0)
@pytest.mark.asyncio
async def test_get_client_reuses_existing_client(self):
"""Test that _get_client reuses an existing open client"""
provider = PerplexityProvider(api_key="test-key")
mock_client = Mock()
mock_client.is_closed = False
provider._client = mock_client
with patch('httpx.AsyncClient') as mock_client_class:
client = await provider._get_client()
assert client is mock_client
mock_client_class.assert_not_called()
@pytest.mark.asyncio
async def test_get_client_recreates_closed_client(self):
"""Test that _get_client creates a new client when existing one is closed"""
provider = PerplexityProvider(api_key="test-key")
old_client = Mock()
old_client.is_closed = True
provider._client = old_client
with patch('httpx.AsyncClient') as mock_client_class:
new_client = Mock()
new_client.is_closed = False
mock_client_class.return_value = new_client
client = await provider._get_client()
assert client is new_client
mock_client_class.assert_called_once()
@pytest.mark.asyncio
async def test_close_closes_client(self):
"""Test that close() properly closes the HTTP client"""
provider = PerplexityProvider(api_key="test-key")
mock_client = Mock()
mock_client.is_closed = False
mock_client.aclose = AsyncMock()
provider._client = mock_client
await provider.close()
mock_client.aclose.assert_called_once()
assert provider._client is None
@pytest.mark.asyncio
async def test_close_handles_none_client(self):
"""Test that close() handles case when client is None"""
provider = PerplexityProvider(api_key="test-key")
provider._client = None
# Should not raise any errors
await provider.close()
@pytest.mark.asyncio
async def test_close_handles_already_closed_client(self):
"""Test that close() handles already closed client"""
provider = PerplexityProvider(api_key="test-key")
mock_client = Mock()
mock_client.is_closed = True
provider._client = mock_client
# Should not raise any errors
await provider.close()

View file

@ -22,48 +22,51 @@ class TestFullWorkflow:
detector = MentionDetector(config.bot_name) detector = MentionDetector(config.bot_name)
ai = PerplexityProvider(api_key=config.perplexity_key, model=config.model) ai = PerplexityProvider(api_key=config.perplexity_key, model=config.model)
# Mock API response # Mock API client for lazy initialization
with patch('httpx.AsyncClient') as mock_client: mock_client = Mock()
mock_response = Mock() mock_client.is_closed = False
mock_response.status_code = 200 mock_response = Mock()
mock_response.json.return_value = mock_perplexity_response mock_response.status_code = 200
mock_response.json.return_value = mock_perplexity_response
mock_post = AsyncMock(return_value=mock_response) mock_client.post = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.post = mock_post
# Simulate user message # Inject the mock client
username = "testuser" ai._client = mock_client
user_message = "@TestBot what's the weather?"
# 1. Detect mention # Simulate user message
assert detector.is_mentioned(user_message) username = "testuser"
user_message = "@TestBot what's the weather?"
# 2. Extract content # 1. Detect mention
content = detector.extract_content(user_message) assert detector.is_mentioned(user_message)
assert content == "what's the weather?"
# 3. Load history # 2. Extract content
history = memory.get_user_history(username, limit=5) content = detector.extract_content(user_message)
assert len(history) == 0 # First message assert content == "what's the weather?"
# 4. Build messages for API # 3. Load history
messages = [{"role": "system", "content": config.get_system_prompt()}] history = memory.get_user_history(username, limit=5)
messages.extend(memory.format_for_prompt(history)) assert len(history) == 0 # First message
messages.append({"role": "user", "content": content})
# 5. Get AI response # 4. Build messages for API
response = await ai.get_response(messages) messages = [{"role": "system", "content": config.get_system_prompt()}]
assert response == "This is a test response from the AI." messages.extend(memory.format_for_prompt(history))
messages.append({"role": "user", "content": content})
# 6. Save to memory # 5. Get AI response
memory.add_message(username, "user", content) response = await ai.get_response(messages)
memory.add_message(username, "assistant", response) assert response == "This is a test response from the AI."
# 7. Verify history was saved # 6. Save to memory
saved_history = memory.get_user_history(username) memory.add_message(username, "user", content)
assert len(saved_history) == 2 memory.add_message(username, "assistant", response)
assert saved_history[0]["content"] == content
assert saved_history[1]["content"] == response # 7. Verify history was saved
saved_history = memory.get_user_history(username)
assert len(saved_history) == 2
assert saved_history[0]["content"] == content
assert saved_history[1]["content"] == response
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_conversation_context_preserved(self, temp_dir, mock_env_file): async def test_conversation_context_preserved(self, temp_dir, mock_env_file):

View file

@ -20,8 +20,22 @@ class MentionDetector:
# Create patterns for various mention formats # Create patterns for various mention formats
# Include bot name and all nicknames # Include bot name and all nicknames
all_names = [bot_name] + self.nicknames all_names = [bot_name] + self.nicknames
self.patterns = []
# Build a single combined pattern for efficient matching (uses alternation)
name_alternatives = "|".join(re.escape(name) for name in all_names)
# Combined pattern that matches any mention format
# This is more efficient than checking multiple patterns separately
combined_mention_pattern = (
rf"(?:@(?:{name_alternatives})\b)|" # @name
rf"(?:\b(?:{name_alternatives})[:!?.,])|" # name: name! etc.
rf"(?:^(?:{name_alternatives})\b)|" # name at start
rf"(?:\b(?:{name_alternatives})\b)" # name anywhere
)
self._mention_pattern = re.compile(combined_mention_pattern, re.IGNORECASE)
# Keep individual patterns list for backward compatibility (tests may use it)
self.patterns = []
for name in all_names: for name in all_names:
self.patterns.extend([ self.patterns.extend([
rf"@{name}\b", # @name (with word boundary) rf"@{name}\b", # @name (with word boundary)
@ -29,12 +43,21 @@ class MentionDetector:
rf"^{name}\b", # name at start of message rf"^{name}\b", # name at start of message
rf"\b{name}\b", # name anywhere as whole word rf"\b{name}\b", # name anywhere as whole word
]) ])
# Case-insensitive compilation
self.compiled_patterns = [ self.compiled_patterns = [
re.compile(pattern, re.IGNORECASE) for pattern in self.patterns re.compile(pattern, re.IGNORECASE) for pattern in self.patterns
] ]
# Pre-compile extraction patterns for each name (more efficient than recompiling each time)
self._extraction_patterns = {}
for name in all_names:
escaped_name = re.escape(name)
self._extraction_patterns[name.lower()] = {
'at_start': re.compile(rf"^@{escaped_name}\b[,:]?\s*", re.IGNORECASE),
'name_start': re.compile(rf"^{escaped_name}\b[,:]?\s*", re.IGNORECASE),
'name_end': re.compile(rf"\s*\b{escaped_name}[,!?.]?\s*$", re.IGNORECASE),
'name_middle': re.compile(rf"\s*\b{escaped_name}[,:!?]\s*", re.IGNORECASE),
}
# Patterns for ambiguous greetings (might be directed at bot) # Patterns for ambiguous greetings (might be directed at bot)
self.greeting_patterns = [ self.greeting_patterns = [
r"^(hi|hey|hallo|hello|servus|moin)(\s|$|\W)", r"^(hi|hey|hallo|hello|servus|moin)(\s|$|\W)",
@ -78,10 +101,8 @@ class MentionDetector:
if not message: if not message:
return False return False
for pattern in self.compiled_patterns: # Use the optimized single combined pattern for faster matching
if pattern.search(message): return bool(self._mention_pattern.search(message))
return True
return False
def is_ambiguous_greeting(self, message): def is_ambiguous_greeting(self, message):
""" """
@ -122,18 +143,21 @@ class MentionDetector:
content = message content = message
all_names = [self.bot_name] + self.nicknames all_names = [self.bot_name] + self.nicknames
# Use pre-compiled patterns for better performance
for name in all_names: for name in all_names:
# Remove @mention at start patterns = self._extraction_patterns.get(name.lower())
content = re.sub(rf"^@{name}\b[,:]?\s*", "", content, flags=re.IGNORECASE) if patterns:
# Remove @mention at start
content = patterns['at_start'].sub("", content)
# Remove name at start with optional punctuation # Remove name at start with optional punctuation
content = re.sub(rf"^{name}\b[,:]?\s*", "", content, flags=re.IGNORECASE) content = patterns['name_start'].sub("", content)
# Remove name at end with optional punctuation # Remove name at end with optional punctuation
content = re.sub(rf"\s*\b{name}[,!?.]?\s*$", "", content, flags=re.IGNORECASE) content = patterns['name_end'].sub("", content)
# Remove name in middle with punctuation # Remove name in middle with punctuation
content = re.sub(rf"\s*\b{name}[,:!?]\s*", " ", content, flags=re.IGNORECASE) content = patterns['name_middle'].sub(" ", content)
return content.strip() return content.strip()