diff --git a/ai_provider.py b/ai_provider.py index 2f6d6e6..49d8855 100644 --- a/ai_provider.py +++ b/ai_provider.py @@ -30,12 +30,27 @@ class PerplexityProvider: self.base_url = "https://api.perplexity.ai" self.logger = logger or logging.getLogger(__name__) + # Reusable HTTP client (created lazily) + self._client: Optional[httpx.AsyncClient] = None + # Statistics self.total_requests = 0 self.total_tokens = 0 self.total_errors = 0 self.last_response_time = 0 + async def _get_client(self) -> httpx.AsyncClient: + """Get or create the HTTP client (lazy initialization)""" + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient(timeout=30.0) + return self._client + + async def close(self): + """Close the HTTP client""" + if self._client is not None and not self._client.is_closed: + await self._client.aclose() + self._client = None + async def get_response(self, messages): """ Send messages to Perplexity API and get response @@ -66,30 +81,30 @@ class PerplexityProvider: start_time = time.time() try: - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post( - f"{self.base_url}/chat/completions", - json=payload, - headers=headers - ) + client = await self._get_client() + response = await client.post( + f"{self.base_url}/chat/completions", + json=payload, + headers=headers + ) - self.last_response_time = time.time() - start_time + self.last_response_time = time.time() - start_time - if response.status_code == 200: - data = response.json() - content = data['choices'][0]['message']['content'] - tokens_used = data.get('usage', {}).get('total_tokens', 0) + if response.status_code == 200: + data = response.json() + content = data['choices'][0]['message']['content'] + tokens_used = data.get('usage', {}).get('total_tokens', 0) - # Update statistics - self.total_requests += 1 - self.total_tokens += tokens_used + # Update statistics + self.total_requests += 1 + self.total_tokens += tokens_used - return content - else: - self.total_errors += 1 - error_msg = f"API Error {response.status_code}: {response.text}" - self.logger.error(error_msg) - return None + return content + else: + self.total_errors += 1 + error_msg = f"API Error {response.status_code}: {response.text}" + self.logger.error(error_msg) + return None except httpx.TimeoutException: self.total_errors += 1 diff --git a/memory.py b/memory.py index e7c07b4..847bd5a 100644 --- a/memory.py +++ b/memory.py @@ -27,6 +27,8 @@ class ConversationMemory: self.max_messages = max_messages self.retention_hours = retention_hours self.logger = logger or logging.getLogger(__name__) + # In-memory cache to reduce file I/O for frequently accessed users + self._cache: Dict[str, List[Dict]] = {} def _get_user_file(self, username): """Get the file path for a user's conversation history""" @@ -34,6 +36,59 @@ class ConversationMemory: safe_username = "".join(c for c in username.lower() if c.isalnum() or c in "._-") return self.data_dir / f"{safe_username}.json" + def _load_user_history(self, username): + """ + Load user history from cache or file + + Args: + username (str): Twitch username + + Returns: + list: List of message dicts or empty list + """ + safe_username = username.lower() + + # Check cache first + if safe_username in self._cache: + return self._cache[safe_username] + + file_path = self._get_user_file(username) + + if not file_path.exists(): + self._cache[safe_username] = [] + return [] + + try: + with open(file_path, 'r', encoding='utf-8') as f: + history = json.load(f) + self._cache[safe_username] = history + return history + except Exception as e: + self.logger.error(f"Error loading history for {username}: {e}") + self._cache[safe_username] = [] + return [] + + def _save_user_history(self, username, history): + """ + Save user history to file and update cache + + Args: + username (str): Twitch username + history (list): List of message dicts + """ + safe_username = username.lower() + file_path = self._get_user_file(username) + + # Update cache + self._cache[safe_username] = history + + # Save to file + try: + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(history, f, ensure_ascii=False, indent=2) + except Exception as e: + self.logger.error(f"Error saving history for {username}: {e}") + def get_user_history(self, username, limit=5): """ Load recent chat history for a user @@ -45,29 +100,22 @@ class ConversationMemory: Returns: list: List of message dicts with role, content, timestamp """ - file_path = self._get_user_file(username) + history = self._load_user_history(username) - if not file_path.exists(): + if not history: return [] - try: - with open(file_path, 'r', encoding='utf-8') as f: - history = json.load(f) - except Exception as e: - self.logger.error(f"Error loading history for {username}: {e}") - return [] - - # Filter by retention time + # Filter by retention time using list comprehension for better performance cutoff_time = datetime.now() - timedelta(hours=self.retention_hours) - recent = [] - for msg in history: + def is_recent(msg): try: msg_time = datetime.fromisoformat(msg['timestamp']) - if msg_time > cutoff_time: - recent.append(msg) + return msg_time > cutoff_time except (KeyError, ValueError): - continue + return False + + recent = [msg for msg in history if is_recent(msg)] # Return only the most recent messages up to limit return recent[-limit:] if recent else [] @@ -81,17 +129,8 @@ class ConversationMemory: role (str): 'user' or 'assistant' content (str): Message content """ - file_path = self._get_user_file(username) - - # Load existing history - history = [] - if file_path.exists(): - try: - with open(file_path, 'r', encoding='utf-8') as f: - history = json.load(f) - except Exception as e: - self.logger.error(f"Error loading history for {username}: {e}") - history = [] + # Load existing history (uses cache if available) + history = self._load_user_history(username) # Add new message history.append({ @@ -104,12 +143,8 @@ class ConversationMemory: if len(history) > self.max_messages: history = history[-self.max_messages:] - # Save back to file - try: - with open(file_path, 'w', encoding='utf-8') as f: - json.dump(history, f, ensure_ascii=False, indent=2) - except Exception as e: - self.logger.error(f"Error saving history for {username}: {e}") + # Save back to file and update cache + self._save_user_history(username, history) def format_for_prompt(self, history): """ @@ -136,6 +171,13 @@ class ConversationMemory: Args: username (str): Twitch username """ + safe_username = username.lower() + + # Clear from cache + if safe_username in self._cache: + del self._cache[safe_username] + + # Clear from disk file_path = self._get_user_file(username) if file_path.exists(): try: diff --git a/tests/test_ai_provider.py b/tests/test_ai_provider.py index 4ce5fd0..dacff27 100644 --- a/tests/test_ai_provider.py +++ b/tests/test_ai_provider.py @@ -3,7 +3,7 @@ Tests for PerplexityProvider AI API class """ import pytest import httpx -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, Mock, patch, PropertyMock from ai_provider import PerplexityProvider @@ -42,28 +42,36 @@ class TestPerplexityProvider: assert provider.total_errors == 0 assert provider.last_response_time == 0 + def test_init_client_starts_as_none(self): + """Test that HTTP client starts as None (lazy initialization)""" + provider = PerplexityProvider(api_key="test-key") + assert provider._client is None + @pytest.mark.asyncio async def test_get_response_success(self, sample_messages, mock_perplexity_response): """Test successful API response""" provider = PerplexityProvider(api_key="test-key") - # Mock the HTTP client - with patch('httpx.AsyncClient') as mock_client: - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = mock_perplexity_response + # Create a mock client with is_closed property + mock_client = Mock() + mock_client.is_closed = False + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_perplexity_response - mock_post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value.post = mock_post + mock_client.post = AsyncMock(return_value=mock_response) - # Call the method - result = await provider.get_response(sample_messages) + # Inject the mock client + provider._client = mock_client - # Verify result - assert result == "This is a test response from the AI." - assert provider.total_requests == 1 - assert provider.total_tokens == 70 # From mock response - assert provider.total_errors == 0 + # Call the method + result = await provider.get_response(sample_messages) + + # Verify result + assert result == "This is a test response from the AI." + assert provider.total_requests == 1 + assert provider.total_tokens == 70 # From mock response + assert provider.total_errors == 0 @pytest.mark.asyncio async def test_get_response_constructs_correct_payload(self, sample_messages): @@ -75,130 +83,146 @@ class TestPerplexityProvider: temperature=0.7 ) - with patch('httpx.AsyncClient') as mock_client: - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "choices": [{"message": {"content": "test"}}], - "usage": {"total_tokens": 10} - } + mock_client = Mock() + mock_client.is_closed = False + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [{"message": {"content": "test"}}], + "usage": {"total_tokens": 10} + } - mock_post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value.post = mock_post + mock_post = AsyncMock(return_value=mock_response) + mock_client.post = mock_post - await provider.get_response(sample_messages) + provider._client = mock_client - # Verify the call was made with correct parameters - call_args = mock_post.call_args - assert call_args[1]["json"]["model"] == "sonar-pro" - assert call_args[1]["json"]["messages"] == sample_messages - assert call_args[1]["json"]["max_tokens"] == 450 - assert call_args[1]["json"]["temperature"] == 0.7 + await provider.get_response(sample_messages) + + # Verify the call was made with correct parameters + call_args = mock_post.call_args + assert call_args[1]["json"]["model"] == "sonar-pro" + assert call_args[1]["json"]["messages"] == sample_messages + assert call_args[1]["json"]["max_tokens"] == 450 + assert call_args[1]["json"]["temperature"] == 0.7 @pytest.mark.asyncio async def test_get_response_includes_auth_header(self, sample_messages): """Test that Authorization header is included""" provider = PerplexityProvider(api_key="test-secret-key") - with patch('httpx.AsyncClient') as mock_client: - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "choices": [{"message": {"content": "test"}}], - "usage": {"total_tokens": 10} - } + mock_client = Mock() + mock_client.is_closed = False + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [{"message": {"content": "test"}}], + "usage": {"total_tokens": 10} + } - mock_post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value.post = mock_post + mock_post = AsyncMock(return_value=mock_response) + mock_client.post = mock_post - await provider.get_response(sample_messages) + provider._client = mock_client - call_args = mock_post.call_args - headers = call_args[1]["headers"] - assert headers["Authorization"] == "Bearer test-secret-key" - assert headers["Content-Type"] == "application/json" + await provider.get_response(sample_messages) + + call_args = mock_post.call_args + headers = call_args[1]["headers"] + assert headers["Authorization"] == "Bearer test-secret-key" + assert headers["Content-Type"] == "application/json" @pytest.mark.asyncio async def test_get_response_handles_401_error(self, sample_messages): """Test handling of 401 Unauthorized error""" provider = PerplexityProvider(api_key="invalid-key") - with patch('httpx.AsyncClient') as mock_client: - mock_response = Mock() - mock_response.status_code = 401 - mock_response.text = "Unauthorized" + mock_client = Mock() + mock_client.is_closed = False + mock_response = Mock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" - mock_post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value.post = mock_post + mock_client.post = AsyncMock(return_value=mock_response) - result = await provider.get_response(sample_messages) + provider._client = mock_client - assert result is None - assert provider.total_errors == 1 - assert provider.total_requests == 0 # Failed requests don't count + result = await provider.get_response(sample_messages) + + assert result is None + assert provider.total_errors == 1 + assert provider.total_requests == 0 # Failed requests don't count @pytest.mark.asyncio async def test_get_response_handles_500_error(self, sample_messages): """Test handling of 500 server error""" provider = PerplexityProvider(api_key="test-key") - with patch('httpx.AsyncClient') as mock_client: - mock_response = Mock() - mock_response.status_code = 500 - mock_response.text = "Internal Server Error" + mock_client = Mock() + mock_client.is_closed = False + mock_response = Mock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" - mock_post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value.post = mock_post + mock_client.post = AsyncMock(return_value=mock_response) - result = await provider.get_response(sample_messages) + provider._client = mock_client - assert result is None - assert provider.total_errors == 1 + result = await provider.get_response(sample_messages) + + assert result is None + assert provider.total_errors == 1 @pytest.mark.asyncio async def test_get_response_handles_timeout(self, sample_messages): """Test handling of request timeout""" provider = PerplexityProvider(api_key="test-key") - with patch('httpx.AsyncClient') as mock_client: - mock_post = AsyncMock(side_effect=httpx.TimeoutException("Timeout")) - mock_client.return_value.__aenter__.return_value.post = mock_post + mock_client = Mock() + mock_client.is_closed = False + mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("Timeout")) - result = await provider.get_response(sample_messages) + provider._client = mock_client - assert result is None - assert provider.total_errors == 1 + result = await provider.get_response(sample_messages) + + assert result is None + assert provider.total_errors == 1 @pytest.mark.asyncio async def test_get_response_handles_network_error(self, sample_messages): """Test handling of network errors""" provider = PerplexityProvider(api_key="test-key") - with patch('httpx.AsyncClient') as mock_client: - mock_post = AsyncMock(side_effect=httpx.NetworkError("Network error")) - mock_client.return_value.__aenter__.return_value.post = mock_post + mock_client = Mock() + mock_client.is_closed = False + mock_client.post = AsyncMock(side_effect=httpx.NetworkError("Network error")) - result = await provider.get_response(sample_messages) + provider._client = mock_client - assert result is None - assert provider.total_errors == 1 + result = await provider.get_response(sample_messages) + + assert result is None + assert provider.total_errors == 1 @pytest.mark.asyncio async def test_get_response_tracks_response_time(self, sample_messages, mock_perplexity_response): """Test that response time is tracked""" provider = PerplexityProvider(api_key="test-key") - with patch('httpx.AsyncClient') as mock_client: - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = mock_perplexity_response + mock_client = Mock() + mock_client.is_closed = False + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_perplexity_response - mock_post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value.post = mock_post + mock_client.post = AsyncMock(return_value=mock_response) - await provider.get_response(sample_messages) + provider._client = mock_client - assert provider.last_response_time > 0 + await provider.get_response(sample_messages) + + assert provider.last_response_time > 0 @pytest.mark.asyncio async def test_validate_api_key_success(self): @@ -282,56 +306,61 @@ class TestPerplexityProvider: """Test statistics after successful requests""" provider = PerplexityProvider(api_key="test-key") - with patch('httpx.AsyncClient') as mock_client: - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = mock_perplexity_response + mock_client = Mock() + mock_client.is_closed = False + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_perplexity_response - mock_post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value.post = mock_post + mock_client.post = AsyncMock(return_value=mock_response) - # Make 2 successful requests - await provider.get_response(sample_messages) - await provider.get_response(sample_messages) + provider._client = mock_client - stats = provider.get_statistics() + # Make 2 successful requests + await provider.get_response(sample_messages) + await provider.get_response(sample_messages) - assert stats["total_requests"] == 2 - assert stats["total_tokens"] == 140 # 70 * 2 - assert stats["total_errors"] == 0 - assert stats["success_rate"] == 100 - assert stats["estimated_cost"] > 0 + stats = provider.get_statistics() + + assert stats["total_requests"] == 2 + assert stats["total_tokens"] == 140 # 70 * 2 + assert stats["total_errors"] == 0 + assert stats["success_rate"] == 100 + assert stats["estimated_cost"] > 0 @pytest.mark.asyncio async def test_get_statistics_with_errors(self, sample_messages): """Test statistics calculation with errors""" provider = PerplexityProvider(api_key="test-key") - with patch('httpx.AsyncClient') as mock_client: - # First request succeeds - mock_response_success = Mock() - mock_response_success.status_code = 200 - mock_response_success.json.return_value = { - "choices": [{"message": {"content": "test"}}], - "usage": {"total_tokens": 50} - } + mock_client = Mock() + mock_client.is_closed = False - # Second request fails - mock_response_fail = Mock() - mock_response_fail.status_code = 500 - mock_response_fail.text = "Error" + # First request succeeds + mock_response_success = Mock() + mock_response_success.status_code = 200 + mock_response_success.json.return_value = { + "choices": [{"message": {"content": "test"}}], + "usage": {"total_tokens": 50} + } - mock_post = AsyncMock(side_effect=[mock_response_success, mock_response_fail]) - mock_client.return_value.__aenter__.return_value.post = mock_post + # Second request fails + mock_response_fail = Mock() + mock_response_fail.status_code = 500 + mock_response_fail.text = "Error" - await provider.get_response(sample_messages) - await provider.get_response(sample_messages) + mock_client.post = AsyncMock(side_effect=[mock_response_success, mock_response_fail]) - stats = provider.get_statistics() + provider._client = mock_client - assert stats["total_requests"] == 1 - assert stats["total_errors"] == 1 - assert stats["success_rate"] == 0 # 0 out of 1 completed request succeeded + await provider.get_response(sample_messages) + await provider.get_response(sample_messages) + + stats = provider.get_statistics() + + assert stats["total_requests"] == 1 + assert stats["total_errors"] == 1 + assert stats["success_rate"] == 0 # 0 out of 1 completed request succeeded def test_reset_statistics(self): """Test resetting statistics""" @@ -355,43 +384,48 @@ class TestPerplexityProvider: """Test handling response without usage data""" provider = PerplexityProvider(api_key="test-key") - with patch('httpx.AsyncClient') as mock_client: - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "choices": [{"message": {"content": "test response"}}] - # No usage field - } + mock_client = Mock() + mock_client.is_closed = False + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [{"message": {"content": "test response"}}] + # No usage field + } - mock_post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value.post = mock_post + mock_client.post = AsyncMock(return_value=mock_response) - result = await provider.get_response(sample_messages) + provider._client = mock_client - assert result == "test response" - assert provider.total_tokens == 0 # Should default to 0 + result = await provider.get_response(sample_messages) + + assert result == "test response" + assert provider.total_tokens == 0 # Should default to 0 @pytest.mark.asyncio async def test_get_response_uses_correct_endpoint(self, sample_messages): """Test that correct API endpoint is used""" provider = PerplexityProvider(api_key="test-key") - with patch('httpx.AsyncClient') as mock_client: - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "choices": [{"message": {"content": "test"}}], - "usage": {"total_tokens": 10} - } + mock_client = Mock() + mock_client.is_closed = False + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [{"message": {"content": "test"}}], + "usage": {"total_tokens": 10} + } - mock_post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value.post = mock_post + mock_post = AsyncMock(return_value=mock_response) + mock_client.post = mock_post - await provider.get_response(sample_messages) + provider._client = mock_client - call_args = mock_post.call_args - # First positional argument should be the URL - assert call_args[0][0] == "https://api.perplexity.ai/chat/completions" + await provider.get_response(sample_messages) + + call_args = mock_post.call_args + # First positional argument should be the URL + assert call_args[0][0] == "https://api.perplexity.ai/chat/completions" @pytest.mark.asyncio async def test_get_response_timeout_configuration(self, sample_messages): @@ -399,6 +433,8 @@ class TestPerplexityProvider: provider = PerplexityProvider(api_key="test-key") with patch('httpx.AsyncClient') as mock_client_class: + mock_client_instance = Mock() + mock_client_instance.is_closed = False mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = { @@ -406,11 +442,96 @@ class TestPerplexityProvider: "usage": {"total_tokens": 10} } - mock_post = AsyncMock(return_value=mock_response) - mock_client_class.return_value.__aenter__.return_value.post = mock_post + mock_client_instance.post = AsyncMock(return_value=mock_response) + mock_client_class.return_value = mock_client_instance await provider.get_response(sample_messages) # Check that AsyncClient was instantiated with timeout call_args = mock_client_class.call_args assert call_args[1]["timeout"] == 30.0 + + @pytest.mark.asyncio + async def test_get_client_creates_client_when_none(self): + """Test that _get_client creates a new client when none exists""" + provider = PerplexityProvider(api_key="test-key") + + with patch('httpx.AsyncClient') as mock_client_class: + mock_client_instance = Mock() + mock_client_instance.is_closed = False + mock_client_class.return_value = mock_client_instance + + client = await provider._get_client() + + assert client is mock_client_instance + mock_client_class.assert_called_once_with(timeout=30.0) + + @pytest.mark.asyncio + async def test_get_client_reuses_existing_client(self): + """Test that _get_client reuses an existing open client""" + provider = PerplexityProvider(api_key="test-key") + + mock_client = Mock() + mock_client.is_closed = False + provider._client = mock_client + + with patch('httpx.AsyncClient') as mock_client_class: + client = await provider._get_client() + + assert client is mock_client + mock_client_class.assert_not_called() + + @pytest.mark.asyncio + async def test_get_client_recreates_closed_client(self): + """Test that _get_client creates a new client when existing one is closed""" + provider = PerplexityProvider(api_key="test-key") + + old_client = Mock() + old_client.is_closed = True + provider._client = old_client + + with patch('httpx.AsyncClient') as mock_client_class: + new_client = Mock() + new_client.is_closed = False + mock_client_class.return_value = new_client + + client = await provider._get_client() + + assert client is new_client + mock_client_class.assert_called_once() + + @pytest.mark.asyncio + async def test_close_closes_client(self): + """Test that close() properly closes the HTTP client""" + provider = PerplexityProvider(api_key="test-key") + + mock_client = Mock() + mock_client.is_closed = False + mock_client.aclose = AsyncMock() + provider._client = mock_client + + await provider.close() + + mock_client.aclose.assert_called_once() + assert provider._client is None + + @pytest.mark.asyncio + async def test_close_handles_none_client(self): + """Test that close() handles case when client is None""" + provider = PerplexityProvider(api_key="test-key") + provider._client = None + + # Should not raise any errors + await provider.close() + + @pytest.mark.asyncio + async def test_close_handles_already_closed_client(self): + """Test that close() handles already closed client""" + provider = PerplexityProvider(api_key="test-key") + + mock_client = Mock() + mock_client.is_closed = True + provider._client = mock_client + + # Should not raise any errors + await provider.close() diff --git a/tests/test_integration.py b/tests/test_integration.py index 1ad9087..08c2708 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -22,48 +22,51 @@ class TestFullWorkflow: detector = MentionDetector(config.bot_name) ai = PerplexityProvider(api_key=config.perplexity_key, model=config.model) - # Mock API response - with patch('httpx.AsyncClient') as mock_client: - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = mock_perplexity_response + # Mock API client for lazy initialization + mock_client = Mock() + mock_client.is_closed = False + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_perplexity_response - mock_post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value.post = mock_post + mock_client.post = AsyncMock(return_value=mock_response) - # Simulate user message - username = "testuser" - user_message = "@TestBot what's the weather?" + # Inject the mock client + ai._client = mock_client - # 1. Detect mention - assert detector.is_mentioned(user_message) + # Simulate user message + username = "testuser" + user_message = "@TestBot what's the weather?" - # 2. Extract content - content = detector.extract_content(user_message) - assert content == "what's the weather?" + # 1. Detect mention + assert detector.is_mentioned(user_message) - # 3. Load history - history = memory.get_user_history(username, limit=5) - assert len(history) == 0 # First message + # 2. Extract content + content = detector.extract_content(user_message) + assert content == "what's the weather?" - # 4. Build messages for API - messages = [{"role": "system", "content": config.get_system_prompt()}] - messages.extend(memory.format_for_prompt(history)) - messages.append({"role": "user", "content": content}) + # 3. Load history + history = memory.get_user_history(username, limit=5) + assert len(history) == 0 # First message - # 5. Get AI response - response = await ai.get_response(messages) - assert response == "This is a test response from the AI." + # 4. Build messages for API + messages = [{"role": "system", "content": config.get_system_prompt()}] + messages.extend(memory.format_for_prompt(history)) + messages.append({"role": "user", "content": content}) - # 6. Save to memory - memory.add_message(username, "user", content) - memory.add_message(username, "assistant", response) + # 5. Get AI response + response = await ai.get_response(messages) + assert response == "This is a test response from the AI." - # 7. Verify history was saved - saved_history = memory.get_user_history(username) - assert len(saved_history) == 2 - assert saved_history[0]["content"] == content - assert saved_history[1]["content"] == response + # 6. Save to memory + memory.add_message(username, "user", content) + memory.add_message(username, "assistant", response) + + # 7. Verify history was saved + saved_history = memory.get_user_history(username) + assert len(saved_history) == 2 + assert saved_history[0]["content"] == content + assert saved_history[1]["content"] == response @pytest.mark.asyncio async def test_conversation_context_preserved(self, temp_dir, mock_env_file): diff --git a/utils.py b/utils.py index 674a170..ee4c24d 100644 --- a/utils.py +++ b/utils.py @@ -20,8 +20,22 @@ class MentionDetector: # Create patterns for various mention formats # Include bot name and all nicknames all_names = [bot_name] + self.nicknames - self.patterns = [] + # Build a single combined pattern for efficient matching (uses alternation) + name_alternatives = "|".join(re.escape(name) for name in all_names) + + # Combined pattern that matches any mention format + # This is more efficient than checking multiple patterns separately + combined_mention_pattern = ( + rf"(?:@(?:{name_alternatives})\b)|" # @name + rf"(?:\b(?:{name_alternatives})[:!?.,])|" # name: name! etc. + rf"(?:^(?:{name_alternatives})\b)|" # name at start + rf"(?:\b(?:{name_alternatives})\b)" # name anywhere + ) + self._mention_pattern = re.compile(combined_mention_pattern, re.IGNORECASE) + + # Keep individual patterns list for backward compatibility (tests may use it) + self.patterns = [] for name in all_names: self.patterns.extend([ rf"@{name}\b", # @name (with word boundary) @@ -29,12 +43,21 @@ class MentionDetector: rf"^{name}\b", # name at start of message rf"\b{name}\b", # name anywhere as whole word ]) - - # Case-insensitive compilation self.compiled_patterns = [ re.compile(pattern, re.IGNORECASE) for pattern in self.patterns ] + # Pre-compile extraction patterns for each name (more efficient than recompiling each time) + self._extraction_patterns = {} + for name in all_names: + escaped_name = re.escape(name) + self._extraction_patterns[name.lower()] = { + 'at_start': re.compile(rf"^@{escaped_name}\b[,:]?\s*", re.IGNORECASE), + 'name_start': re.compile(rf"^{escaped_name}\b[,:]?\s*", re.IGNORECASE), + 'name_end': re.compile(rf"\s*\b{escaped_name}[,!?.]?\s*$", re.IGNORECASE), + 'name_middle': re.compile(rf"\s*\b{escaped_name}[,:!?]\s*", re.IGNORECASE), + } + # Patterns for ambiguous greetings (might be directed at bot) self.greeting_patterns = [ r"^(hi|hey|hallo|hello|servus|moin)(\s|$|\W)", @@ -78,10 +101,8 @@ class MentionDetector: if not message: return False - for pattern in self.compiled_patterns: - if pattern.search(message): - return True - return False + # Use the optimized single combined pattern for faster matching + return bool(self._mention_pattern.search(message)) def is_ambiguous_greeting(self, message): """ @@ -122,18 +143,21 @@ class MentionDetector: content = message all_names = [self.bot_name] + self.nicknames + # Use pre-compiled patterns for better performance for name in all_names: - # Remove @mention at start - content = re.sub(rf"^@{name}\b[,:]?\s*", "", content, flags=re.IGNORECASE) + patterns = self._extraction_patterns.get(name.lower()) + if patterns: + # Remove @mention at start + content = patterns['at_start'].sub("", content) - # Remove name at start with optional punctuation - content = re.sub(rf"^{name}\b[,:]?\s*", "", content, flags=re.IGNORECASE) + # Remove name at start with optional punctuation + content = patterns['name_start'].sub("", content) - # Remove name at end with optional punctuation - content = re.sub(rf"\s*\b{name}[,!?.]?\s*$", "", content, flags=re.IGNORECASE) + # Remove name at end with optional punctuation + content = patterns['name_end'].sub("", content) - # Remove name in middle with punctuation - content = re.sub(rf"\s*\b{name}[,:!?]\s*", " ", content, flags=re.IGNORECASE) + # Remove name in middle with punctuation + content = patterns['name_middle'].sub(" ", content) return content.strip()