Implement performance improvements for memory caching, HTTP client reuse, and regex optimization
Co-authored-by: Kenearos <86194771+Kenearos@users.noreply.github.com>
This commit is contained in:
parent
860e4d5027
commit
b72cd9db1c
5 changed files with 451 additions and 246 deletions
|
|
@ -3,7 +3,7 @@ Tests for PerplexityProvider AI API class
|
|||
"""
|
||||
import pytest
|
||||
import httpx
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch, PropertyMock
|
||||
from ai_provider import PerplexityProvider
|
||||
|
||||
|
||||
|
|
@ -42,28 +42,36 @@ class TestPerplexityProvider:
|
|||
assert provider.total_errors == 0
|
||||
assert provider.last_response_time == 0
|
||||
|
||||
def test_init_client_starts_as_none(self):
|
||||
"""Test that HTTP client starts as None (lazy initialization)"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
assert provider._client is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_response_success(self, sample_messages, mock_perplexity_response):
|
||||
"""Test successful API response"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
# Mock the HTTP client
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_perplexity_response
|
||||
# Create a mock client with is_closed property
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_perplexity_response
|
||||
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Call the method
|
||||
result = await provider.get_response(sample_messages)
|
||||
# Inject the mock client
|
||||
provider._client = mock_client
|
||||
|
||||
# Verify result
|
||||
assert result == "This is a test response from the AI."
|
||||
assert provider.total_requests == 1
|
||||
assert provider.total_tokens == 70 # From mock response
|
||||
assert provider.total_errors == 0
|
||||
# Call the method
|
||||
result = await provider.get_response(sample_messages)
|
||||
|
||||
# Verify result
|
||||
assert result == "This is a test response from the AI."
|
||||
assert provider.total_requests == 1
|
||||
assert provider.total_tokens == 70 # From mock response
|
||||
assert provider.total_errors == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_response_constructs_correct_payload(self, sample_messages):
|
||||
|
|
@ -75,130 +83,146 @@ class TestPerplexityProvider:
|
|||
temperature=0.7
|
||||
)
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "test"}}],
|
||||
"usage": {"total_tokens": 10}
|
||||
}
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "test"}}],
|
||||
"usage": {"total_tokens": 10}
|
||||
}
|
||||
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mock_client.post = mock_post
|
||||
|
||||
await provider.get_response(sample_messages)
|
||||
provider._client = mock_client
|
||||
|
||||
# Verify the call was made with correct parameters
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["model"] == "sonar-pro"
|
||||
assert call_args[1]["json"]["messages"] == sample_messages
|
||||
assert call_args[1]["json"]["max_tokens"] == 450
|
||||
assert call_args[1]["json"]["temperature"] == 0.7
|
||||
await provider.get_response(sample_messages)
|
||||
|
||||
# Verify the call was made with correct parameters
|
||||
call_args = mock_post.call_args
|
||||
assert call_args[1]["json"]["model"] == "sonar-pro"
|
||||
assert call_args[1]["json"]["messages"] == sample_messages
|
||||
assert call_args[1]["json"]["max_tokens"] == 450
|
||||
assert call_args[1]["json"]["temperature"] == 0.7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_response_includes_auth_header(self, sample_messages):
|
||||
"""Test that Authorization header is included"""
|
||||
provider = PerplexityProvider(api_key="test-secret-key")
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "test"}}],
|
||||
"usage": {"total_tokens": 10}
|
||||
}
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "test"}}],
|
||||
"usage": {"total_tokens": 10}
|
||||
}
|
||||
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mock_client.post = mock_post
|
||||
|
||||
await provider.get_response(sample_messages)
|
||||
provider._client = mock_client
|
||||
|
||||
call_args = mock_post.call_args
|
||||
headers = call_args[1]["headers"]
|
||||
assert headers["Authorization"] == "Bearer test-secret-key"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
await provider.get_response(sample_messages)
|
||||
|
||||
call_args = mock_post.call_args
|
||||
headers = call_args[1]["headers"]
|
||||
assert headers["Authorization"] == "Bearer test-secret-key"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_response_handles_401_error(self, sample_messages):
|
||||
"""Test handling of 401 Unauthorized error"""
|
||||
provider = PerplexityProvider(api_key="invalid-key")
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Unauthorized"
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 401
|
||||
mock_response.text = "Unauthorized"
|
||||
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await provider.get_response(sample_messages)
|
||||
provider._client = mock_client
|
||||
|
||||
assert result is None
|
||||
assert provider.total_errors == 1
|
||||
assert provider.total_requests == 0 # Failed requests don't count
|
||||
result = await provider.get_response(sample_messages)
|
||||
|
||||
assert result is None
|
||||
assert provider.total_errors == 1
|
||||
assert provider.total_requests == 0 # Failed requests don't count
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_response_handles_500_error(self, sample_messages):
|
||||
"""Test handling of 500 server error"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "Internal Server Error"
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "Internal Server Error"
|
||||
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await provider.get_response(sample_messages)
|
||||
provider._client = mock_client
|
||||
|
||||
assert result is None
|
||||
assert provider.total_errors == 1
|
||||
result = await provider.get_response(sample_messages)
|
||||
|
||||
assert result is None
|
||||
assert provider.total_errors == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_response_handles_timeout(self, sample_messages):
|
||||
"""Test handling of request timeout"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_post = AsyncMock(side_effect=httpx.TimeoutException("Timeout"))
|
||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
mock_client.post = AsyncMock(side_effect=httpx.TimeoutException("Timeout"))
|
||||
|
||||
result = await provider.get_response(sample_messages)
|
||||
provider._client = mock_client
|
||||
|
||||
assert result is None
|
||||
assert provider.total_errors == 1
|
||||
result = await provider.get_response(sample_messages)
|
||||
|
||||
assert result is None
|
||||
assert provider.total_errors == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_response_handles_network_error(self, sample_messages):
|
||||
"""Test handling of network errors"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_post = AsyncMock(side_effect=httpx.NetworkError("Network error"))
|
||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
mock_client.post = AsyncMock(side_effect=httpx.NetworkError("Network error"))
|
||||
|
||||
result = await provider.get_response(sample_messages)
|
||||
provider._client = mock_client
|
||||
|
||||
assert result is None
|
||||
assert provider.total_errors == 1
|
||||
result = await provider.get_response(sample_messages)
|
||||
|
||||
assert result is None
|
||||
assert provider.total_errors == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_response_tracks_response_time(self, sample_messages, mock_perplexity_response):
|
||||
"""Test that response time is tracked"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_perplexity_response
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_perplexity_response
|
||||
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
await provider.get_response(sample_messages)
|
||||
provider._client = mock_client
|
||||
|
||||
assert provider.last_response_time > 0
|
||||
await provider.get_response(sample_messages)
|
||||
|
||||
assert provider.last_response_time > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_success(self):
|
||||
|
|
@ -282,56 +306,61 @@ class TestPerplexityProvider:
|
|||
"""Test statistics after successful requests"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_perplexity_response
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_perplexity_response
|
||||
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Make 2 successful requests
|
||||
await provider.get_response(sample_messages)
|
||||
await provider.get_response(sample_messages)
|
||||
provider._client = mock_client
|
||||
|
||||
stats = provider.get_statistics()
|
||||
# Make 2 successful requests
|
||||
await provider.get_response(sample_messages)
|
||||
await provider.get_response(sample_messages)
|
||||
|
||||
assert stats["total_requests"] == 2
|
||||
assert stats["total_tokens"] == 140 # 70 * 2
|
||||
assert stats["total_errors"] == 0
|
||||
assert stats["success_rate"] == 100
|
||||
assert stats["estimated_cost"] > 0
|
||||
stats = provider.get_statistics()
|
||||
|
||||
assert stats["total_requests"] == 2
|
||||
assert stats["total_tokens"] == 140 # 70 * 2
|
||||
assert stats["total_errors"] == 0
|
||||
assert stats["success_rate"] == 100
|
||||
assert stats["estimated_cost"] > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_statistics_with_errors(self, sample_messages):
|
||||
"""Test statistics calculation with errors"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
# First request succeeds
|
||||
mock_response_success = Mock()
|
||||
mock_response_success.status_code = 200
|
||||
mock_response_success.json.return_value = {
|
||||
"choices": [{"message": {"content": "test"}}],
|
||||
"usage": {"total_tokens": 50}
|
||||
}
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
|
||||
# Second request fails
|
||||
mock_response_fail = Mock()
|
||||
mock_response_fail.status_code = 500
|
||||
mock_response_fail.text = "Error"
|
||||
# First request succeeds
|
||||
mock_response_success = Mock()
|
||||
mock_response_success.status_code = 200
|
||||
mock_response_success.json.return_value = {
|
||||
"choices": [{"message": {"content": "test"}}],
|
||||
"usage": {"total_tokens": 50}
|
||||
}
|
||||
|
||||
mock_post = AsyncMock(side_effect=[mock_response_success, mock_response_fail])
|
||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||
# Second request fails
|
||||
mock_response_fail = Mock()
|
||||
mock_response_fail.status_code = 500
|
||||
mock_response_fail.text = "Error"
|
||||
|
||||
await provider.get_response(sample_messages)
|
||||
await provider.get_response(sample_messages)
|
||||
mock_client.post = AsyncMock(side_effect=[mock_response_success, mock_response_fail])
|
||||
|
||||
stats = provider.get_statistics()
|
||||
provider._client = mock_client
|
||||
|
||||
assert stats["total_requests"] == 1
|
||||
assert stats["total_errors"] == 1
|
||||
assert stats["success_rate"] == 0 # 0 out of 1 completed request succeeded
|
||||
await provider.get_response(sample_messages)
|
||||
await provider.get_response(sample_messages)
|
||||
|
||||
stats = provider.get_statistics()
|
||||
|
||||
assert stats["total_requests"] == 1
|
||||
assert stats["total_errors"] == 1
|
||||
assert stats["success_rate"] == 0 # 0 out of 1 completed request succeeded
|
||||
|
||||
def test_reset_statistics(self):
|
||||
"""Test resetting statistics"""
|
||||
|
|
@ -355,43 +384,48 @@ class TestPerplexityProvider:
|
|||
"""Test handling response without usage data"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "test response"}}]
|
||||
# No usage field
|
||||
}
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "test response"}}]
|
||||
# No usage field
|
||||
}
|
||||
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await provider.get_response(sample_messages)
|
||||
provider._client = mock_client
|
||||
|
||||
assert result == "test response"
|
||||
assert provider.total_tokens == 0 # Should default to 0
|
||||
result = await provider.get_response(sample_messages)
|
||||
|
||||
assert result == "test response"
|
||||
assert provider.total_tokens == 0 # Should default to 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_response_uses_correct_endpoint(self, sample_messages):
|
||||
"""Test that correct API endpoint is used"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "test"}}],
|
||||
"usage": {"total_tokens": 10}
|
||||
}
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "test"}}],
|
||||
"usage": {"total_tokens": 10}
|
||||
}
|
||||
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mock_client.post = mock_post
|
||||
|
||||
await provider.get_response(sample_messages)
|
||||
provider._client = mock_client
|
||||
|
||||
call_args = mock_post.call_args
|
||||
# First positional argument should be the URL
|
||||
assert call_args[0][0] == "https://api.perplexity.ai/chat/completions"
|
||||
await provider.get_response(sample_messages)
|
||||
|
||||
call_args = mock_post.call_args
|
||||
# First positional argument should be the URL
|
||||
assert call_args[0][0] == "https://api.perplexity.ai/chat/completions"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_response_timeout_configuration(self, sample_messages):
|
||||
|
|
@ -399,6 +433,8 @@ class TestPerplexityProvider:
|
|||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client_instance = Mock()
|
||||
mock_client_instance.is_closed = False
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
|
|
@ -406,11 +442,96 @@ class TestPerplexityProvider:
|
|||
"usage": {"total_tokens": 10}
|
||||
}
|
||||
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mock_client_class.return_value.__aenter__.return_value.post = mock_post
|
||||
mock_client_instance.post = AsyncMock(return_value=mock_response)
|
||||
mock_client_class.return_value = mock_client_instance
|
||||
|
||||
await provider.get_response(sample_messages)
|
||||
|
||||
# Check that AsyncClient was instantiated with timeout
|
||||
call_args = mock_client_class.call_args
|
||||
assert call_args[1]["timeout"] == 30.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_creates_client_when_none(self):
|
||||
"""Test that _get_client creates a new client when none exists"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client_instance = Mock()
|
||||
mock_client_instance.is_closed = False
|
||||
mock_client_class.return_value = mock_client_instance
|
||||
|
||||
client = await provider._get_client()
|
||||
|
||||
assert client is mock_client_instance
|
||||
mock_client_class.assert_called_once_with(timeout=30.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_reuses_existing_client(self):
|
||||
"""Test that _get_client reuses an existing open client"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
provider._client = mock_client
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
client = await provider._get_client()
|
||||
|
||||
assert client is mock_client
|
||||
mock_client_class.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_recreates_closed_client(self):
|
||||
"""Test that _get_client creates a new client when existing one is closed"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
old_client = Mock()
|
||||
old_client.is_closed = True
|
||||
provider._client = old_client
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
new_client = Mock()
|
||||
new_client.is_closed = False
|
||||
mock_client_class.return_value = new_client
|
||||
|
||||
client = await provider._get_client()
|
||||
|
||||
assert client is new_client
|
||||
mock_client_class.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_closes_client(self):
|
||||
"""Test that close() properly closes the HTTP client"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
mock_client.aclose = AsyncMock()
|
||||
provider._client = mock_client
|
||||
|
||||
await provider.close()
|
||||
|
||||
mock_client.aclose.assert_called_once()
|
||||
assert provider._client is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_handles_none_client(self):
|
||||
"""Test that close() handles case when client is None"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
provider._client = None
|
||||
|
||||
# Should not raise any errors
|
||||
await provider.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_handles_already_closed_client(self):
|
||||
"""Test that close() handles already closed client"""
|
||||
provider = PerplexityProvider(api_key="test-key")
|
||||
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = True
|
||||
provider._client = mock_client
|
||||
|
||||
# Should not raise any errors
|
||||
await provider.close()
|
||||
|
|
|
|||
|
|
@ -22,48 +22,51 @@ class TestFullWorkflow:
|
|||
detector = MentionDetector(config.bot_name)
|
||||
ai = PerplexityProvider(api_key=config.perplexity_key, model=config.model)
|
||||
|
||||
# Mock API response
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_perplexity_response
|
||||
# Mock API client for lazy initialization
|
||||
mock_client = Mock()
|
||||
mock_client.is_closed = False
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_perplexity_response
|
||||
|
||||
mock_post = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Simulate user message
|
||||
username = "testuser"
|
||||
user_message = "@TestBot what's the weather?"
|
||||
# Inject the mock client
|
||||
ai._client = mock_client
|
||||
|
||||
# 1. Detect mention
|
||||
assert detector.is_mentioned(user_message)
|
||||
# Simulate user message
|
||||
username = "testuser"
|
||||
user_message = "@TestBot what's the weather?"
|
||||
|
||||
# 2. Extract content
|
||||
content = detector.extract_content(user_message)
|
||||
assert content == "what's the weather?"
|
||||
# 1. Detect mention
|
||||
assert detector.is_mentioned(user_message)
|
||||
|
||||
# 3. Load history
|
||||
history = memory.get_user_history(username, limit=5)
|
||||
assert len(history) == 0 # First message
|
||||
# 2. Extract content
|
||||
content = detector.extract_content(user_message)
|
||||
assert content == "what's the weather?"
|
||||
|
||||
# 4. Build messages for API
|
||||
messages = [{"role": "system", "content": config.get_system_prompt()}]
|
||||
messages.extend(memory.format_for_prompt(history))
|
||||
messages.append({"role": "user", "content": content})
|
||||
# 3. Load history
|
||||
history = memory.get_user_history(username, limit=5)
|
||||
assert len(history) == 0 # First message
|
||||
|
||||
# 5. Get AI response
|
||||
response = await ai.get_response(messages)
|
||||
assert response == "This is a test response from the AI."
|
||||
# 4. Build messages for API
|
||||
messages = [{"role": "system", "content": config.get_system_prompt()}]
|
||||
messages.extend(memory.format_for_prompt(history))
|
||||
messages.append({"role": "user", "content": content})
|
||||
|
||||
# 6. Save to memory
|
||||
memory.add_message(username, "user", content)
|
||||
memory.add_message(username, "assistant", response)
|
||||
# 5. Get AI response
|
||||
response = await ai.get_response(messages)
|
||||
assert response == "This is a test response from the AI."
|
||||
|
||||
# 7. Verify history was saved
|
||||
saved_history = memory.get_user_history(username)
|
||||
assert len(saved_history) == 2
|
||||
assert saved_history[0]["content"] == content
|
||||
assert saved_history[1]["content"] == response
|
||||
# 6. Save to memory
|
||||
memory.add_message(username, "user", content)
|
||||
memory.add_message(username, "assistant", response)
|
||||
|
||||
# 7. Verify history was saved
|
||||
saved_history = memory.get_user_history(username)
|
||||
assert len(saved_history) == 2
|
||||
assert saved_history[0]["content"] == content
|
||||
assert saved_history[1]["content"] == response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_context_preserved(self, temp_dir, mock_env_file):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue