Implement performance improvements for memory caching, HTTP client reuse, and regex optimization
Co-authored-by: Kenearos <86194771+Kenearos@users.noreply.github.com>
This commit is contained in:
parent
860e4d5027
commit
b72cd9db1c
5 changed files with 451 additions and 246 deletions
|
|
@ -30,12 +30,27 @@ class PerplexityProvider:
|
||||||
self.base_url = "https://api.perplexity.ai"
|
self.base_url = "https://api.perplexity.ai"
|
||||||
self.logger = logger or logging.getLogger(__name__)
|
self.logger = logger or logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Reusable HTTP client (created lazily)
|
||||||
|
self._client: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
# Statistics
|
# Statistics
|
||||||
self.total_requests = 0
|
self.total_requests = 0
|
||||||
self.total_tokens = 0
|
self.total_tokens = 0
|
||||||
self.total_errors = 0
|
self.total_errors = 0
|
||||||
self.last_response_time = 0
|
self.last_response_time = 0
|
||||||
|
|
||||||
|
async def _get_client(self) -> httpx.AsyncClient:
|
||||||
|
"""Get or create the HTTP client (lazy initialization)"""
|
||||||
|
if self._client is None or self._client.is_closed:
|
||||||
|
self._client = httpx.AsyncClient(timeout=30.0)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close the HTTP client"""
|
||||||
|
if self._client is not None and not self._client.is_closed:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
async def get_response(self, messages):
|
async def get_response(self, messages):
|
||||||
"""
|
"""
|
||||||
Send messages to Perplexity API and get response
|
Send messages to Perplexity API and get response
|
||||||
|
|
@ -66,7 +81,7 @@ class PerplexityProvider:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
client = await self._get_client()
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{self.base_url}/chat/completions",
|
f"{self.base_url}/chat/completions",
|
||||||
json=payload,
|
json=payload,
|
||||||
|
|
|
||||||
106
memory.py
106
memory.py
|
|
@ -27,6 +27,8 @@ class ConversationMemory:
|
||||||
self.max_messages = max_messages
|
self.max_messages = max_messages
|
||||||
self.retention_hours = retention_hours
|
self.retention_hours = retention_hours
|
||||||
self.logger = logger or logging.getLogger(__name__)
|
self.logger = logger or logging.getLogger(__name__)
|
||||||
|
# In-memory cache to reduce file I/O for frequently accessed users
|
||||||
|
self._cache: Dict[str, List[Dict]] = {}
|
||||||
|
|
||||||
def _get_user_file(self, username):
|
def _get_user_file(self, username):
|
||||||
"""Get the file path for a user's conversation history"""
|
"""Get the file path for a user's conversation history"""
|
||||||
|
|
@ -34,6 +36,59 @@ class ConversationMemory:
|
||||||
safe_username = "".join(c for c in username.lower() if c.isalnum() or c in "._-")
|
safe_username = "".join(c for c in username.lower() if c.isalnum() or c in "._-")
|
||||||
return self.data_dir / f"{safe_username}.json"
|
return self.data_dir / f"{safe_username}.json"
|
||||||
|
|
||||||
|
def _load_user_history(self, username):
|
||||||
|
"""
|
||||||
|
Load user history from cache or file
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username (str): Twitch username
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of message dicts or empty list
|
||||||
|
"""
|
||||||
|
safe_username = username.lower()
|
||||||
|
|
||||||
|
# Check cache first
|
||||||
|
if safe_username in self._cache:
|
||||||
|
return self._cache[safe_username]
|
||||||
|
|
||||||
|
file_path = self._get_user_file(username)
|
||||||
|
|
||||||
|
if not file_path.exists():
|
||||||
|
self._cache[safe_username] = []
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
history = json.load(f)
|
||||||
|
self._cache[safe_username] = history
|
||||||
|
return history
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error loading history for {username}: {e}")
|
||||||
|
self._cache[safe_username] = []
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _save_user_history(self, username, history):
|
||||||
|
"""
|
||||||
|
Save user history to file and update cache
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username (str): Twitch username
|
||||||
|
history (list): List of message dicts
|
||||||
|
"""
|
||||||
|
safe_username = username.lower()
|
||||||
|
file_path = self._get_user_file(username)
|
||||||
|
|
||||||
|
# Update cache
|
||||||
|
self._cache[safe_username] = history
|
||||||
|
|
||||||
|
# Save to file
|
||||||
|
try:
|
||||||
|
with open(file_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(history, f, ensure_ascii=False, indent=2)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error saving history for {username}: {e}")
|
||||||
|
|
||||||
def get_user_history(self, username, limit=5):
|
def get_user_history(self, username, limit=5):
|
||||||
"""
|
"""
|
||||||
Load recent chat history for a user
|
Load recent chat history for a user
|
||||||
|
|
@ -45,29 +100,22 @@ class ConversationMemory:
|
||||||
Returns:
|
Returns:
|
||||||
list: List of message dicts with role, content, timestamp
|
list: List of message dicts with role, content, timestamp
|
||||||
"""
|
"""
|
||||||
file_path = self._get_user_file(username)
|
history = self._load_user_history(username)
|
||||||
|
|
||||||
if not file_path.exists():
|
if not history:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
# Filter by retention time using list comprehension for better performance
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
|
||||||
history = json.load(f)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error loading history for {username}: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Filter by retention time
|
|
||||||
cutoff_time = datetime.now() - timedelta(hours=self.retention_hours)
|
cutoff_time = datetime.now() - timedelta(hours=self.retention_hours)
|
||||||
recent = []
|
|
||||||
|
|
||||||
for msg in history:
|
def is_recent(msg):
|
||||||
try:
|
try:
|
||||||
msg_time = datetime.fromisoformat(msg['timestamp'])
|
msg_time = datetime.fromisoformat(msg['timestamp'])
|
||||||
if msg_time > cutoff_time:
|
return msg_time > cutoff_time
|
||||||
recent.append(msg)
|
|
||||||
except (KeyError, ValueError):
|
except (KeyError, ValueError):
|
||||||
continue
|
return False
|
||||||
|
|
||||||
|
recent = [msg for msg in history if is_recent(msg)]
|
||||||
|
|
||||||
# Return only the most recent messages up to limit
|
# Return only the most recent messages up to limit
|
||||||
return recent[-limit:] if recent else []
|
return recent[-limit:] if recent else []
|
||||||
|
|
@ -81,17 +129,8 @@ class ConversationMemory:
|
||||||
role (str): 'user' or 'assistant'
|
role (str): 'user' or 'assistant'
|
||||||
content (str): Message content
|
content (str): Message content
|
||||||
"""
|
"""
|
||||||
file_path = self._get_user_file(username)
|
# Load existing history (uses cache if available)
|
||||||
|
history = self._load_user_history(username)
|
||||||
# Load existing history
|
|
||||||
history = []
|
|
||||||
if file_path.exists():
|
|
||||||
try:
|
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
|
||||||
history = json.load(f)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error loading history for {username}: {e}")
|
|
||||||
history = []
|
|
||||||
|
|
||||||
# Add new message
|
# Add new message
|
||||||
history.append({
|
history.append({
|
||||||
|
|
@ -104,12 +143,8 @@ class ConversationMemory:
|
||||||
if len(history) > self.max_messages:
|
if len(history) > self.max_messages:
|
||||||
history = history[-self.max_messages:]
|
history = history[-self.max_messages:]
|
||||||
|
|
||||||
# Save back to file
|
# Save back to file and update cache
|
||||||
try:
|
self._save_user_history(username, history)
|
||||||
with open(file_path, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(history, f, ensure_ascii=False, indent=2)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error saving history for {username}: {e}")
|
|
||||||
|
|
||||||
def format_for_prompt(self, history):
|
def format_for_prompt(self, history):
|
||||||
"""
|
"""
|
||||||
|
|
@ -136,6 +171,13 @@ class ConversationMemory:
|
||||||
Args:
|
Args:
|
||||||
username (str): Twitch username
|
username (str): Twitch username
|
||||||
"""
|
"""
|
||||||
|
safe_username = username.lower()
|
||||||
|
|
||||||
|
# Clear from cache
|
||||||
|
if safe_username in self._cache:
|
||||||
|
del self._cache[safe_username]
|
||||||
|
|
||||||
|
# Clear from disk
|
||||||
file_path = self._get_user_file(username)
|
file_path = self._get_user_file(username)
|
||||||
if file_path.exists():
|
if file_path.exists():
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ Tests for PerplexityProvider AI API class
|
||||||
"""
|
"""
|
||||||
import pytest
|
import pytest
|
||||||
import httpx
|
import httpx
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch, PropertyMock
|
||||||
from ai_provider import PerplexityProvider
|
from ai_provider import PerplexityProvider
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -42,19 +42,27 @@ class TestPerplexityProvider:
|
||||||
assert provider.total_errors == 0
|
assert provider.total_errors == 0
|
||||||
assert provider.last_response_time == 0
|
assert provider.last_response_time == 0
|
||||||
|
|
||||||
|
def test_init_client_starts_as_none(self):
|
||||||
|
"""Test that HTTP client starts as None (lazy initialization)"""
|
||||||
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
assert provider._client is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_response_success(self, sample_messages, mock_perplexity_response):
|
async def test_get_response_success(self, sample_messages, mock_perplexity_response):
|
||||||
"""Test successful API response"""
|
"""Test successful API response"""
|
||||||
provider = PerplexityProvider(api_key="test-key")
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
# Mock the HTTP client
|
# Create a mock client with is_closed property
|
||||||
with patch('httpx.AsyncClient') as mock_client:
|
mock_client = Mock()
|
||||||
|
mock_client.is_closed = False
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_response.json.return_value = mock_perplexity_response
|
mock_response.json.return_value = mock_perplexity_response
|
||||||
|
|
||||||
mock_post = AsyncMock(return_value=mock_response)
|
mock_client.post = AsyncMock(return_value=mock_response)
|
||||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
|
||||||
|
# Inject the mock client
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
# Call the method
|
# Call the method
|
||||||
result = await provider.get_response(sample_messages)
|
result = await provider.get_response(sample_messages)
|
||||||
|
|
@ -75,7 +83,8 @@ class TestPerplexityProvider:
|
||||||
temperature=0.7
|
temperature=0.7
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch('httpx.AsyncClient') as mock_client:
|
mock_client = Mock()
|
||||||
|
mock_client.is_closed = False
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_response.json.return_value = {
|
mock_response.json.return_value = {
|
||||||
|
|
@ -84,7 +93,9 @@ class TestPerplexityProvider:
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_post = AsyncMock(return_value=mock_response)
|
mock_post = AsyncMock(return_value=mock_response)
|
||||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
mock_client.post = mock_post
|
||||||
|
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
await provider.get_response(sample_messages)
|
await provider.get_response(sample_messages)
|
||||||
|
|
||||||
|
|
@ -100,7 +111,8 @@ class TestPerplexityProvider:
|
||||||
"""Test that Authorization header is included"""
|
"""Test that Authorization header is included"""
|
||||||
provider = PerplexityProvider(api_key="test-secret-key")
|
provider = PerplexityProvider(api_key="test-secret-key")
|
||||||
|
|
||||||
with patch('httpx.AsyncClient') as mock_client:
|
mock_client = Mock()
|
||||||
|
mock_client.is_closed = False
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_response.json.return_value = {
|
mock_response.json.return_value = {
|
||||||
|
|
@ -109,7 +121,9 @@ class TestPerplexityProvider:
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_post = AsyncMock(return_value=mock_response)
|
mock_post = AsyncMock(return_value=mock_response)
|
||||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
mock_client.post = mock_post
|
||||||
|
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
await provider.get_response(sample_messages)
|
await provider.get_response(sample_messages)
|
||||||
|
|
||||||
|
|
@ -123,13 +137,15 @@ class TestPerplexityProvider:
|
||||||
"""Test handling of 401 Unauthorized error"""
|
"""Test handling of 401 Unauthorized error"""
|
||||||
provider = PerplexityProvider(api_key="invalid-key")
|
provider = PerplexityProvider(api_key="invalid-key")
|
||||||
|
|
||||||
with patch('httpx.AsyncClient') as mock_client:
|
mock_client = Mock()
|
||||||
|
mock_client.is_closed = False
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 401
|
mock_response.status_code = 401
|
||||||
mock_response.text = "Unauthorized"
|
mock_response.text = "Unauthorized"
|
||||||
|
|
||||||
mock_post = AsyncMock(return_value=mock_response)
|
mock_client.post = AsyncMock(return_value=mock_response)
|
||||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
result = await provider.get_response(sample_messages)
|
result = await provider.get_response(sample_messages)
|
||||||
|
|
||||||
|
|
@ -142,13 +158,15 @@ class TestPerplexityProvider:
|
||||||
"""Test handling of 500 server error"""
|
"""Test handling of 500 server error"""
|
||||||
provider = PerplexityProvider(api_key="test-key")
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
with patch('httpx.AsyncClient') as mock_client:
|
mock_client = Mock()
|
||||||
|
mock_client.is_closed = False
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 500
|
mock_response.status_code = 500
|
||||||
mock_response.text = "Internal Server Error"
|
mock_response.text = "Internal Server Error"
|
||||||
|
|
||||||
mock_post = AsyncMock(return_value=mock_response)
|
mock_client.post = AsyncMock(return_value=mock_response)
|
||||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
result = await provider.get_response(sample_messages)
|
result = await provider.get_response(sample_messages)
|
||||||
|
|
||||||
|
|
@ -160,9 +178,11 @@ class TestPerplexityProvider:
|
||||||
"""Test handling of request timeout"""
|
"""Test handling of request timeout"""
|
||||||
provider = PerplexityProvider(api_key="test-key")
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
with patch('httpx.AsyncClient') as mock_client:
|
mock_client = Mock()
|
||||||
mock_post = AsyncMock(side_effect=httpx.TimeoutException("Timeout"))
|
mock_client.is_closed = False
|
||||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("Timeout"))
|
||||||
|
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
result = await provider.get_response(sample_messages)
|
result = await provider.get_response(sample_messages)
|
||||||
|
|
||||||
|
|
@ -174,9 +194,11 @@ class TestPerplexityProvider:
|
||||||
"""Test handling of network errors"""
|
"""Test handling of network errors"""
|
||||||
provider = PerplexityProvider(api_key="test-key")
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
with patch('httpx.AsyncClient') as mock_client:
|
mock_client = Mock()
|
||||||
mock_post = AsyncMock(side_effect=httpx.NetworkError("Network error"))
|
mock_client.is_closed = False
|
||||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
mock_client.post = AsyncMock(side_effect=httpx.NetworkError("Network error"))
|
||||||
|
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
result = await provider.get_response(sample_messages)
|
result = await provider.get_response(sample_messages)
|
||||||
|
|
||||||
|
|
@ -188,13 +210,15 @@ class TestPerplexityProvider:
|
||||||
"""Test that response time is tracked"""
|
"""Test that response time is tracked"""
|
||||||
provider = PerplexityProvider(api_key="test-key")
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
with patch('httpx.AsyncClient') as mock_client:
|
mock_client = Mock()
|
||||||
|
mock_client.is_closed = False
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_response.json.return_value = mock_perplexity_response
|
mock_response.json.return_value = mock_perplexity_response
|
||||||
|
|
||||||
mock_post = AsyncMock(return_value=mock_response)
|
mock_client.post = AsyncMock(return_value=mock_response)
|
||||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
await provider.get_response(sample_messages)
|
await provider.get_response(sample_messages)
|
||||||
|
|
||||||
|
|
@ -282,13 +306,15 @@ class TestPerplexityProvider:
|
||||||
"""Test statistics after successful requests"""
|
"""Test statistics after successful requests"""
|
||||||
provider = PerplexityProvider(api_key="test-key")
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
with patch('httpx.AsyncClient') as mock_client:
|
mock_client = Mock()
|
||||||
|
mock_client.is_closed = False
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_response.json.return_value = mock_perplexity_response
|
mock_response.json.return_value = mock_perplexity_response
|
||||||
|
|
||||||
mock_post = AsyncMock(return_value=mock_response)
|
mock_client.post = AsyncMock(return_value=mock_response)
|
||||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
# Make 2 successful requests
|
# Make 2 successful requests
|
||||||
await provider.get_response(sample_messages)
|
await provider.get_response(sample_messages)
|
||||||
|
|
@ -307,7 +333,9 @@ class TestPerplexityProvider:
|
||||||
"""Test statistics calculation with errors"""
|
"""Test statistics calculation with errors"""
|
||||||
provider = PerplexityProvider(api_key="test-key")
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
with patch('httpx.AsyncClient') as mock_client:
|
mock_client = Mock()
|
||||||
|
mock_client.is_closed = False
|
||||||
|
|
||||||
# First request succeeds
|
# First request succeeds
|
||||||
mock_response_success = Mock()
|
mock_response_success = Mock()
|
||||||
mock_response_success.status_code = 200
|
mock_response_success.status_code = 200
|
||||||
|
|
@ -321,8 +349,9 @@ class TestPerplexityProvider:
|
||||||
mock_response_fail.status_code = 500
|
mock_response_fail.status_code = 500
|
||||||
mock_response_fail.text = "Error"
|
mock_response_fail.text = "Error"
|
||||||
|
|
||||||
mock_post = AsyncMock(side_effect=[mock_response_success, mock_response_fail])
|
mock_client.post = AsyncMock(side_effect=[mock_response_success, mock_response_fail])
|
||||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
await provider.get_response(sample_messages)
|
await provider.get_response(sample_messages)
|
||||||
await provider.get_response(sample_messages)
|
await provider.get_response(sample_messages)
|
||||||
|
|
@ -355,7 +384,8 @@ class TestPerplexityProvider:
|
||||||
"""Test handling response without usage data"""
|
"""Test handling response without usage data"""
|
||||||
provider = PerplexityProvider(api_key="test-key")
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
with patch('httpx.AsyncClient') as mock_client:
|
mock_client = Mock()
|
||||||
|
mock_client.is_closed = False
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_response.json.return_value = {
|
mock_response.json.return_value = {
|
||||||
|
|
@ -363,8 +393,9 @@ class TestPerplexityProvider:
|
||||||
# No usage field
|
# No usage field
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_post = AsyncMock(return_value=mock_response)
|
mock_client.post = AsyncMock(return_value=mock_response)
|
||||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
result = await provider.get_response(sample_messages)
|
result = await provider.get_response(sample_messages)
|
||||||
|
|
||||||
|
|
@ -376,7 +407,8 @@ class TestPerplexityProvider:
|
||||||
"""Test that correct API endpoint is used"""
|
"""Test that correct API endpoint is used"""
|
||||||
provider = PerplexityProvider(api_key="test-key")
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
with patch('httpx.AsyncClient') as mock_client:
|
mock_client = Mock()
|
||||||
|
mock_client.is_closed = False
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_response.json.return_value = {
|
mock_response.json.return_value = {
|
||||||
|
|
@ -385,7 +417,9 @@ class TestPerplexityProvider:
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_post = AsyncMock(return_value=mock_response)
|
mock_post = AsyncMock(return_value=mock_response)
|
||||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
mock_client.post = mock_post
|
||||||
|
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
await provider.get_response(sample_messages)
|
await provider.get_response(sample_messages)
|
||||||
|
|
||||||
|
|
@ -399,6 +433,8 @@ class TestPerplexityProvider:
|
||||||
provider = PerplexityProvider(api_key="test-key")
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
with patch('httpx.AsyncClient') as mock_client_class:
|
with patch('httpx.AsyncClient') as mock_client_class:
|
||||||
|
mock_client_instance = Mock()
|
||||||
|
mock_client_instance.is_closed = False
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_response.json.return_value = {
|
mock_response.json.return_value = {
|
||||||
|
|
@ -406,11 +442,96 @@ class TestPerplexityProvider:
|
||||||
"usage": {"total_tokens": 10}
|
"usage": {"total_tokens": 10}
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_post = AsyncMock(return_value=mock_response)
|
mock_client_instance.post = AsyncMock(return_value=mock_response)
|
||||||
mock_client_class.return_value.__aenter__.return_value.post = mock_post
|
mock_client_class.return_value = mock_client_instance
|
||||||
|
|
||||||
await provider.get_response(sample_messages)
|
await provider.get_response(sample_messages)
|
||||||
|
|
||||||
# Check that AsyncClient was instantiated with timeout
|
# Check that AsyncClient was instantiated with timeout
|
||||||
call_args = mock_client_class.call_args
|
call_args = mock_client_class.call_args
|
||||||
assert call_args[1]["timeout"] == 30.0
|
assert call_args[1]["timeout"] == 30.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_client_creates_client_when_none(self):
|
||||||
|
"""Test that _get_client creates a new client when none exists"""
|
||||||
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
|
with patch('httpx.AsyncClient') as mock_client_class:
|
||||||
|
mock_client_instance = Mock()
|
||||||
|
mock_client_instance.is_closed = False
|
||||||
|
mock_client_class.return_value = mock_client_instance
|
||||||
|
|
||||||
|
client = await provider._get_client()
|
||||||
|
|
||||||
|
assert client is mock_client_instance
|
||||||
|
mock_client_class.assert_called_once_with(timeout=30.0)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_client_reuses_existing_client(self):
|
||||||
|
"""Test that _get_client reuses an existing open client"""
|
||||||
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_client.is_closed = False
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
|
with patch('httpx.AsyncClient') as mock_client_class:
|
||||||
|
client = await provider._get_client()
|
||||||
|
|
||||||
|
assert client is mock_client
|
||||||
|
mock_client_class.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_client_recreates_closed_client(self):
|
||||||
|
"""Test that _get_client creates a new client when existing one is closed"""
|
||||||
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
|
old_client = Mock()
|
||||||
|
old_client.is_closed = True
|
||||||
|
provider._client = old_client
|
||||||
|
|
||||||
|
with patch('httpx.AsyncClient') as mock_client_class:
|
||||||
|
new_client = Mock()
|
||||||
|
new_client.is_closed = False
|
||||||
|
mock_client_class.return_value = new_client
|
||||||
|
|
||||||
|
client = await provider._get_client()
|
||||||
|
|
||||||
|
assert client is new_client
|
||||||
|
mock_client_class.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_closes_client(self):
|
||||||
|
"""Test that close() properly closes the HTTP client"""
|
||||||
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_client.is_closed = False
|
||||||
|
mock_client.aclose = AsyncMock()
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
|
await provider.close()
|
||||||
|
|
||||||
|
mock_client.aclose.assert_called_once()
|
||||||
|
assert provider._client is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_handles_none_client(self):
|
||||||
|
"""Test that close() handles case when client is None"""
|
||||||
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
provider._client = None
|
||||||
|
|
||||||
|
# Should not raise any errors
|
||||||
|
await provider.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_handles_already_closed_client(self):
|
||||||
|
"""Test that close() handles already closed client"""
|
||||||
|
provider = PerplexityProvider(api_key="test-key")
|
||||||
|
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_client.is_closed = True
|
||||||
|
provider._client = mock_client
|
||||||
|
|
||||||
|
# Should not raise any errors
|
||||||
|
await provider.close()
|
||||||
|
|
|
||||||
|
|
@ -22,14 +22,17 @@ class TestFullWorkflow:
|
||||||
detector = MentionDetector(config.bot_name)
|
detector = MentionDetector(config.bot_name)
|
||||||
ai = PerplexityProvider(api_key=config.perplexity_key, model=config.model)
|
ai = PerplexityProvider(api_key=config.perplexity_key, model=config.model)
|
||||||
|
|
||||||
# Mock API response
|
# Mock API client for lazy initialization
|
||||||
with patch('httpx.AsyncClient') as mock_client:
|
mock_client = Mock()
|
||||||
|
mock_client.is_closed = False
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_response.json.return_value = mock_perplexity_response
|
mock_response.json.return_value = mock_perplexity_response
|
||||||
|
|
||||||
mock_post = AsyncMock(return_value=mock_response)
|
mock_client.post = AsyncMock(return_value=mock_response)
|
||||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
|
||||||
|
# Inject the mock client
|
||||||
|
ai._client = mock_client
|
||||||
|
|
||||||
# Simulate user message
|
# Simulate user message
|
||||||
username = "testuser"
|
username = "testuser"
|
||||||
|
|
|
||||||
46
utils.py
46
utils.py
|
|
@ -20,8 +20,22 @@ class MentionDetector:
|
||||||
# Create patterns for various mention formats
|
# Create patterns for various mention formats
|
||||||
# Include bot name and all nicknames
|
# Include bot name and all nicknames
|
||||||
all_names = [bot_name] + self.nicknames
|
all_names = [bot_name] + self.nicknames
|
||||||
self.patterns = []
|
|
||||||
|
|
||||||
|
# Build a single combined pattern for efficient matching (uses alternation)
|
||||||
|
name_alternatives = "|".join(re.escape(name) for name in all_names)
|
||||||
|
|
||||||
|
# Combined pattern that matches any mention format
|
||||||
|
# This is more efficient than checking multiple patterns separately
|
||||||
|
combined_mention_pattern = (
|
||||||
|
rf"(?:@(?:{name_alternatives})\b)|" # @name
|
||||||
|
rf"(?:\b(?:{name_alternatives})[:!?.,])|" # name: name! etc.
|
||||||
|
rf"(?:^(?:{name_alternatives})\b)|" # name at start
|
||||||
|
rf"(?:\b(?:{name_alternatives})\b)" # name anywhere
|
||||||
|
)
|
||||||
|
self._mention_pattern = re.compile(combined_mention_pattern, re.IGNORECASE)
|
||||||
|
|
||||||
|
# Keep individual patterns list for backward compatibility (tests may use it)
|
||||||
|
self.patterns = []
|
||||||
for name in all_names:
|
for name in all_names:
|
||||||
self.patterns.extend([
|
self.patterns.extend([
|
||||||
rf"@{name}\b", # @name (with word boundary)
|
rf"@{name}\b", # @name (with word boundary)
|
||||||
|
|
@ -29,12 +43,21 @@ class MentionDetector:
|
||||||
rf"^{name}\b", # name at start of message
|
rf"^{name}\b", # name at start of message
|
||||||
rf"\b{name}\b", # name anywhere as whole word
|
rf"\b{name}\b", # name anywhere as whole word
|
||||||
])
|
])
|
||||||
|
|
||||||
# Case-insensitive compilation
|
|
||||||
self.compiled_patterns = [
|
self.compiled_patterns = [
|
||||||
re.compile(pattern, re.IGNORECASE) for pattern in self.patterns
|
re.compile(pattern, re.IGNORECASE) for pattern in self.patterns
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Pre-compile extraction patterns for each name (more efficient than recompiling each time)
|
||||||
|
self._extraction_patterns = {}
|
||||||
|
for name in all_names:
|
||||||
|
escaped_name = re.escape(name)
|
||||||
|
self._extraction_patterns[name.lower()] = {
|
||||||
|
'at_start': re.compile(rf"^@{escaped_name}\b[,:]?\s*", re.IGNORECASE),
|
||||||
|
'name_start': re.compile(rf"^{escaped_name}\b[,:]?\s*", re.IGNORECASE),
|
||||||
|
'name_end': re.compile(rf"\s*\b{escaped_name}[,!?.]?\s*$", re.IGNORECASE),
|
||||||
|
'name_middle': re.compile(rf"\s*\b{escaped_name}[,:!?]\s*", re.IGNORECASE),
|
||||||
|
}
|
||||||
|
|
||||||
# Patterns for ambiguous greetings (might be directed at bot)
|
# Patterns for ambiguous greetings (might be directed at bot)
|
||||||
self.greeting_patterns = [
|
self.greeting_patterns = [
|
||||||
r"^(hi|hey|hallo|hello|servus|moin)(\s|$|\W)",
|
r"^(hi|hey|hallo|hello|servus|moin)(\s|$|\W)",
|
||||||
|
|
@ -78,10 +101,8 @@ class MentionDetector:
|
||||||
if not message:
|
if not message:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
for pattern in self.compiled_patterns:
|
# Use the optimized single combined pattern for faster matching
|
||||||
if pattern.search(message):
|
return bool(self._mention_pattern.search(message))
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def is_ambiguous_greeting(self, message):
|
def is_ambiguous_greeting(self, message):
|
||||||
"""
|
"""
|
||||||
|
|
@ -122,18 +143,21 @@ class MentionDetector:
|
||||||
content = message
|
content = message
|
||||||
all_names = [self.bot_name] + self.nicknames
|
all_names = [self.bot_name] + self.nicknames
|
||||||
|
|
||||||
|
# Use pre-compiled patterns for better performance
|
||||||
for name in all_names:
|
for name in all_names:
|
||||||
|
patterns = self._extraction_patterns.get(name.lower())
|
||||||
|
if patterns:
|
||||||
# Remove @mention at start
|
# Remove @mention at start
|
||||||
content = re.sub(rf"^@{name}\b[,:]?\s*", "", content, flags=re.IGNORECASE)
|
content = patterns['at_start'].sub("", content)
|
||||||
|
|
||||||
# Remove name at start with optional punctuation
|
# Remove name at start with optional punctuation
|
||||||
content = re.sub(rf"^{name}\b[,:]?\s*", "", content, flags=re.IGNORECASE)
|
content = patterns['name_start'].sub("", content)
|
||||||
|
|
||||||
# Remove name at end with optional punctuation
|
# Remove name at end with optional punctuation
|
||||||
content = re.sub(rf"\s*\b{name}[,!?.]?\s*$", "", content, flags=re.IGNORECASE)
|
content = patterns['name_end'].sub("", content)
|
||||||
|
|
||||||
# Remove name in middle with punctuation
|
# Remove name in middle with punctuation
|
||||||
content = re.sub(rf"\s*\b{name}[,:!?]\s*", " ", content, flags=re.IGNORECASE)
|
content = patterns['name_middle'].sub(" ", content)
|
||||||
|
|
||||||
return content.strip()
|
return content.strip()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue