# Listing metadata: 486 lines, 16 KiB, Python
"""
OAuth User Pool Management for Load Testing.

Manages multiple OAuth-authenticated users for realistic multi-user load testing scenarios.
"""
|
|
|
|
import logging
|
|
import secrets
|
|
import string
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
from urllib.parse import quote
|
|
|
|
import anyio
|
|
import httpx
|
|
from mcp import ClientSession
|
|
from mcp.client.streamable_http import streamablehttp_client
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class UserConfig:
    """
    Describe one test user.

    All five fields are required and are accepted positionally in the order
    they are declared below.
    """

    # Login name of the user.
    username: str
    # Password of the user.
    password: str
    # Human-readable name of the user.
    display_name: str
    # E-mail address of the user.
    email: str
    # Names of the groups associated with the user.
    groups: list[str]
|
|
|
|
|
|
@dataclass
|
|
class UserProfile:
|
|
"""Profile for an OAuth-authenticated user."""
|
|
|
|
username: str
|
|
password: str
|
|
token: str
|
|
session: ClientSession | None = None
|
|
streamable_context: Any | None = None # Store for proper cleanup
|
|
operation_count: int = 0
|
|
error_count: int = 0
|
|
|
|
|
|
class OAuthUserPool:
|
|
"""
|
|
Manages a pool of OAuth-authenticated users for load testing.
|
|
|
|
Handles token acquisition, session management, and user lifecycle.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
admin_client: Any, # NextcloudClient with admin credentials
|
|
client_id: str,
|
|
client_secret: str,
|
|
callback_url: str,
|
|
token_endpoint: str,
|
|
authorization_endpoint: str,
|
|
):
|
|
self.admin_client = admin_client # For user management
|
|
self.nextcloud_host = str(admin_client._client.base_url)
|
|
self.client_id = client_id
|
|
self.client_secret = client_secret
|
|
self.callback_url = callback_url
|
|
self.token_endpoint = token_endpoint
|
|
self.authorization_endpoint = authorization_endpoint
|
|
self.users: dict[str, UserProfile] = {}
|
|
self._http_client: httpx.AsyncClient | None = None
|
|
|
|
async def __aenter__(self):
|
|
"""Initialize HTTP client."""
|
|
self._http_client = httpx.AsyncClient(verify=False, timeout=30.0)
|
|
return self
|
|
|
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
"""Cleanup HTTP client."""
|
|
if self._http_client:
|
|
await self._http_client.aclose()
|
|
|
|
async def acquire_token(self, username: str, password: str, auth_code: str) -> str:
|
|
"""
|
|
Exchange authorization code for OAuth access token.
|
|
|
|
Args:
|
|
username: Username for logging
|
|
password: Password (for logging/debugging)
|
|
auth_code: Authorization code from OAuth flow
|
|
|
|
Returns:
|
|
OAuth access token
|
|
"""
|
|
logger.info(f"Exchanging auth code for access token (user: {username})...")
|
|
|
|
if not self._http_client:
|
|
raise RuntimeError(
|
|
"HTTP client not initialized - use async context manager"
|
|
)
|
|
|
|
# Exchange authorization code for access token
|
|
token_response = await self._http_client.post(
|
|
self.token_endpoint,
|
|
data={
|
|
"grant_type": "authorization_code",
|
|
"code": auth_code,
|
|
"redirect_uri": self.callback_url,
|
|
"client_id": self.client_id,
|
|
"client_secret": self.client_secret,
|
|
},
|
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
)
|
|
token_response.raise_for_status()
|
|
token_data = token_response.json()
|
|
|
|
access_token = token_data.get("access_token")
|
|
if not access_token:
|
|
raise ValueError(f"No access token in response for {username}")
|
|
|
|
logger.info(f"Successfully acquired OAuth token for {username}")
|
|
return access_token
|
|
|
|
async def add_user(self, username: str, password: str, token: str) -> UserProfile:
|
|
"""
|
|
Add a user to the pool with their OAuth token.
|
|
|
|
Args:
|
|
username: Username
|
|
password: Password (for future re-auth if needed)
|
|
token: OAuth access token
|
|
|
|
Returns:
|
|
UserProfile for the added user
|
|
"""
|
|
if username in self.users:
|
|
logger.warning(f"User {username} already in pool, updating token")
|
|
|
|
profile = UserProfile(username=username, password=password, token=token)
|
|
self.users[username] = profile
|
|
logger.info(f"Added user {username} to pool (total: {len(self.users)})")
|
|
return profile
|
|
|
|
async def create_user_session(
|
|
self, username: str, mcp_url: str = "http://localhost:8001/mcp"
|
|
) -> ClientSession:
|
|
"""
|
|
Create an MCP client session for a user.
|
|
|
|
Args:
|
|
username: Username to create session for
|
|
mcp_url: MCP server URL
|
|
|
|
Returns:
|
|
Initialized ClientSession
|
|
|
|
Raises:
|
|
KeyError: If user not in pool
|
|
"""
|
|
if username not in self.users:
|
|
raise KeyError(f"User {username} not in pool")
|
|
|
|
profile = self.users[username]
|
|
|
|
# Create streamable HTTP connection with OAuth token in Authorization header
|
|
# This matches the pattern from tests/conftest.py create_mcp_client_session()
|
|
headers = {"Authorization": f"Bearer {profile.token}"}
|
|
streamable_context = streamablehttp_client(mcp_url, headers=headers)
|
|
|
|
try:
|
|
read_stream, write_stream, _ = await streamable_context.__aenter__()
|
|
|
|
session = ClientSession(read_stream, write_stream)
|
|
await session.__aenter__()
|
|
await session.initialize()
|
|
|
|
# Store both session and context for proper cleanup
|
|
profile.session = session
|
|
profile.streamable_context = streamable_context
|
|
logger.info(f"Created MCP session for {username}")
|
|
return session
|
|
|
|
except Exception as e:
|
|
# Clean up streamable context if session creation failed
|
|
try:
|
|
await streamable_context.__aexit__(None, None, None)
|
|
except Exception as cleanup_error:
|
|
logger.debug(f"Error during cleanup: {cleanup_error}")
|
|
raise e
|
|
|
|
async def close_user_session(self, username: str):
|
|
"""Close the MCP session for a user."""
|
|
if username not in self.users:
|
|
return
|
|
|
|
profile = self.users[username]
|
|
|
|
# Close ClientSession
|
|
if profile.session:
|
|
try:
|
|
await profile.session.__aexit__(None, None, None)
|
|
except Exception as e:
|
|
logger.debug(f"Error closing session for {username}: {e}")
|
|
profile.session = None
|
|
|
|
# Close streamable context
|
|
if profile.streamable_context:
|
|
try:
|
|
await profile.streamable_context.__aexit__(None, None, None)
|
|
except Exception as e:
|
|
logger.debug(f"Error closing streamable context for {username}: {e}")
|
|
profile.streamable_context = None
|
|
|
|
async def close_all_sessions(self):
|
|
"""Close all user sessions."""
|
|
for username in list(self.users.keys()):
|
|
await self.close_user_session(username)
|
|
|
|
def get_user(self, username: str) -> UserProfile:
|
|
"""Get user profile by username."""
|
|
if username not in self.users:
|
|
raise KeyError(f"User {username} not in pool")
|
|
return self.users[username]
|
|
|
|
def get_all_users(self) -> list[UserProfile]:
|
|
"""Get all user profiles."""
|
|
return list(self.users.values())
|
|
|
|
def record_operation(self, username: str, success: bool = True):
|
|
"""Record an operation for user stats."""
|
|
if username in self.users:
|
|
self.users[username].operation_count += 1
|
|
if not success:
|
|
self.users[username].error_count += 1
|
|
|
|
def get_stats(self) -> dict[str, dict[str, int | float]]:
|
|
"""Get per-user operation statistics."""
|
|
return {
|
|
username: {
|
|
"operations": profile.operation_count,
|
|
"errors": profile.error_count,
|
|
"success_rate": (
|
|
(profile.operation_count - profile.error_count)
|
|
/ max(profile.operation_count, 1)
|
|
* 100
|
|
),
|
|
}
|
|
for username, profile in self.users.items()
|
|
}
|
|
|
|
async def create_nextcloud_user(
|
|
self,
|
|
username: str,
|
|
password: str,
|
|
display_name: str | None = None,
|
|
email: str | None = None,
|
|
) -> UserConfig:
|
|
"""
|
|
Create a Nextcloud user via the Users API.
|
|
|
|
Args:
|
|
username: Username for the new user
|
|
password: Password for the new user
|
|
display_name: Optional display name
|
|
email: Optional email address
|
|
|
|
Returns:
|
|
UserConfig for the created user
|
|
|
|
Raises:
|
|
HTTPStatusError: If user creation fails
|
|
"""
|
|
logger.info(f"Creating Nextcloud user: {username}")
|
|
|
|
await self.admin_client.users.create_user(
|
|
userid=username,
|
|
password=password,
|
|
display_name=display_name or username,
|
|
email=email or f"{username}@benchmark.local",
|
|
)
|
|
|
|
logger.info(f"Successfully created Nextcloud user: {username}")
|
|
|
|
return UserConfig(
|
|
username=username,
|
|
password=password,
|
|
display_name=display_name or username,
|
|
email=email or f"{username}@benchmark.local",
|
|
groups=[],
|
|
)
|
|
|
|
async def delete_nextcloud_user(self, username: str):
|
|
"""
|
|
Delete a Nextcloud user via the Users API.
|
|
|
|
Args:
|
|
username: Username to delete
|
|
"""
|
|
logger.info(f"Deleting Nextcloud user: {username}")
|
|
|
|
try:
|
|
await self.admin_client.users.delete_user(userid=username)
|
|
logger.info(f"Successfully deleted Nextcloud user: {username}")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to delete user {username}: {e}")
|
|
|
|
async def acquire_token_playwright(
|
|
self,
|
|
browser: Any,
|
|
username: str,
|
|
password: str,
|
|
state: str,
|
|
auth_states: dict[str, str],
|
|
) -> str:
|
|
"""
|
|
Acquire OAuth token via Playwright browser automation.
|
|
|
|
Based on conftest.py playwright_oauth_token fixture.
|
|
Automates the full OAuth flow:
|
|
1. Navigate to authorization URL
|
|
2. Fill login form
|
|
3. Handle OAuth consent
|
|
4. Wait for callback server to receive auth code
|
|
5. Exchange code for access token
|
|
|
|
Args:
|
|
browser: Playwright browser instance
|
|
username: Username to authenticate
|
|
password: Password for the user
|
|
state: Unique state parameter for this OAuth flow
|
|
auth_states: Dict mapping state -> auth_code (shared with callback server)
|
|
|
|
Returns:
|
|
OAuth access token
|
|
|
|
Raises:
|
|
TimeoutError: If callback not received within timeout
|
|
ValueError: If token exchange fails
|
|
"""
|
|
|
|
logger.info(f"Starting Playwright OAuth flow for {username}...")
|
|
logger.debug(f"Using state: {state[:16]}...")
|
|
|
|
# Construct authorization URL
|
|
auth_url = (
|
|
f"{self.authorization_endpoint}?"
|
|
f"response_type=code&"
|
|
f"client_id={self.client_id}&"
|
|
f"redirect_uri={quote(self.callback_url, safe='')}&"
|
|
f"state={state}&"
|
|
f"scope=openid%20profile%20email"
|
|
)
|
|
|
|
# Browser automation
|
|
context = await browser.new_context(ignore_https_errors=True)
|
|
page = await context.new_page()
|
|
|
|
try:
|
|
# Navigate to authorization URL
|
|
logger.debug("Navigating to authorization URL...")
|
|
await page.goto(auth_url, wait_until="networkidle", timeout=30000)
|
|
current_url = page.url
|
|
|
|
# Login if needed
|
|
if "/login" in current_url or "/index.php/login" in current_url:
|
|
logger.info(f"Logging in as {username}...")
|
|
await page.wait_for_selector('input[name="user"]', timeout=10000)
|
|
await page.fill('input[name="user"]', username)
|
|
await page.fill('input[name="password"]', password)
|
|
await page.click('button[type="submit"]')
|
|
await page.wait_for_load_state("networkidle", timeout=30000)
|
|
current_url = page.url
|
|
logger.info("Login completed")
|
|
|
|
# Handle OAuth consent if present
|
|
try:
|
|
authorize_button = await page.query_selector(
|
|
'button:has-text("Authorize"), button:has-text("Allow"), input[type="submit"][value*="uthoriz"]'
|
|
)
|
|
if authorize_button:
|
|
logger.info("Authorizing OAuth client...")
|
|
await authorize_button.click()
|
|
await page.wait_for_load_state("networkidle", timeout=10000)
|
|
except Exception as e:
|
|
logger.debug(f"No authorization needed: {e}")
|
|
|
|
# Wait for callback server to receive auth code
|
|
logger.info("Waiting for OAuth callback...")
|
|
timeout_seconds = 30
|
|
start_time = time.time()
|
|
while state not in auth_states:
|
|
if time.time() - start_time > timeout_seconds:
|
|
screenshot_path = f"/tmp/oauth_timeout_{username}.png"
|
|
await page.screenshot(path=screenshot_path)
|
|
logger.error(f"Screenshot saved to {screenshot_path}")
|
|
raise TimeoutError(
|
|
f"Timeout waiting for OAuth callback for {username}"
|
|
)
|
|
await anyio.sleep(0.5)
|
|
|
|
auth_code = auth_states[state]
|
|
logger.info(f"Received auth code for {username}")
|
|
|
|
finally:
|
|
await context.close()
|
|
|
|
# Exchange code for token
|
|
logger.info(f"Exchanging auth code for access token ({username})...")
|
|
token_response = await self._http_client.post(
|
|
self.token_endpoint,
|
|
data={
|
|
"grant_type": "authorization_code",
|
|
"code": auth_code,
|
|
"redirect_uri": self.callback_url,
|
|
"client_id": self.client_id,
|
|
"client_secret": self.client_secret,
|
|
},
|
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
)
|
|
token_response.raise_for_status()
|
|
token_data = token_response.json()
|
|
|
|
access_token = token_data.get("access_token")
|
|
if not access_token:
|
|
raise ValueError(f"No access token for {username}: {token_data}")
|
|
|
|
logger.info(f"Successfully acquired OAuth token for {username}")
|
|
return access_token
|
|
|
|
|
|
class UserSessionWrapper:
    """
    Bind an MCP session to one username and count what is done through it.

    Every call made through the wrapper is reported to the pool via
    ``record_operation``, as a success or as a failure.
    """

    def __init__(self, username: str, session: "ClientSession", pool: "OAuthUserPool"):
        self.username, self.session, self.pool = username, session, pool

    async def _invoke_tracked(self, method_name: str, *call_args: Any) -> Any:
        """Await ``session.<method_name>(*call_args)`` and record the outcome on the pool."""
        try:
            outcome = await getattr(self.session, method_name)(*call_args)
            self.pool.record_operation(self.username, success=True)
            return outcome
        except Exception:
            # Count the failure, then let the caller see the original error.
            self.pool.record_operation(self.username, success=False)
            raise

    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """
        Invoke an MCP tool on behalf of this user and record the outcome.

        Args:
            tool_name: Name of the tool to call
            arguments: Tool arguments

        Returns:
            Whatever the session's ``call_tool`` returns
        """
        return await self._invoke_tracked("call_tool", tool_name, arguments)

    async def read_resource(self, uri: str) -> Any:
        """
        Read an MCP resource on behalf of this user and record the outcome.

        Args:
            uri: Resource URI

        Returns:
            Whatever the session's ``read_resource`` returns
        """
        return await self._invoke_tracked("read_resource", uri)
|
|
|
|
|
|
def generate_secure_password(length: int = 20) -> str:
    """
    Return a random password of ``length`` characters.

    Characters are drawn independently with ``secrets.choice`` from ASCII
    letters, digits and the symbols ``!@#$%^&*()``. A ``length`` of zero or
    less yields an empty string.
    """
    symbols = "!@#$%^&*()"
    candidates = "".join((string.ascii_letters, string.digits, symbols))
    picked = [secrets.choice(candidates) for _ in range(length)]
    return "".join(picked)
|