""" OAuth User Pool Management for Load Testing. Manages multiple OAuth-authenticated users for realistic multi-user load testing scenarios. """ import logging import secrets import string import time from dataclasses import dataclass from typing import Any from urllib.parse import quote import anyio import httpx from mcp import ClientSession from mcp.client.streamable_http import streamablehttp_client logger = logging.getLogger(__name__) @dataclass class UserConfig: """Configuration for a single test user.""" username: str password: str display_name: str email: str groups: list[str] @dataclass class UserProfile: """Profile for an OAuth-authenticated user.""" username: str password: str token: str session: ClientSession | None = None streamable_context: Any | None = None # Store for proper cleanup operation_count: int = 0 error_count: int = 0 class OAuthUserPool: """ Manages a pool of OAuth-authenticated users for load testing. Handles token acquisition, session management, and user lifecycle. """ def __init__( self, admin_client: Any, # NextcloudClient with admin credentials client_id: str, client_secret: str, callback_url: str, token_endpoint: str, authorization_endpoint: str, ): self.admin_client = admin_client # For user management self.nextcloud_host = str(admin_client._client.base_url) self.client_id = client_id self.client_secret = client_secret self.callback_url = callback_url self.token_endpoint = token_endpoint self.authorization_endpoint = authorization_endpoint self.users: dict[str, UserProfile] = {} self._http_client: httpx.AsyncClient | None = None async def __aenter__(self): """Initialize HTTP client.""" self._http_client = httpx.AsyncClient(verify=False, timeout=30.0) return self async def __aexit__(self, exc_type, exc_val, exc_tb): """Cleanup HTTP client.""" if self._http_client: await self._http_client.aclose() async def acquire_token(self, username: str, password: str, auth_code: str) -> str: """ Exchange authorization code for OAuth access token. Args: username: Username for logging password: Password (for logging/debugging) auth_code: Authorization code from OAuth flow Returns: OAuth access token """ logger.info(f"Exchanging auth code for access token (user: {username})...") if not self._http_client: raise RuntimeError( "HTTP client not initialized - use async context manager" ) # Exchange authorization code for access token token_response = await self._http_client.post( self.token_endpoint, data={ "grant_type": "authorization_code", "code": auth_code, "redirect_uri": self.callback_url, "client_id": self.client_id, "client_secret": self.client_secret, }, headers={"Content-Type": "application/x-www-form-urlencoded"}, ) token_response.raise_for_status() token_data = token_response.json() access_token = token_data.get("access_token") if not access_token: raise ValueError(f"No access token in response for {username}") logger.info(f"Successfully acquired OAuth token for {username}") return access_token async def add_user(self, username: str, password: str, token: str) -> UserProfile: """ Add a user to the pool with their OAuth token. Args: username: Username password: Password (for future re-auth if needed) token: OAuth access token Returns: UserProfile for the added user """ if username in self.users: logger.warning(f"User {username} already in pool, updating token") profile = UserProfile(username=username, password=password, token=token) self.users[username] = profile logger.info(f"Added user {username} to pool (total: {len(self.users)})") return profile async def create_user_session( self, username: str, mcp_url: str = "http://localhost:8001/mcp" ) -> ClientSession: """ Create an MCP client session for a user. Args: username: Username to create session for mcp_url: MCP server URL Returns: Initialized ClientSession Raises: KeyError: If user not in pool """ if username not in self.users: raise KeyError(f"User {username} not in pool") profile = self.users[username] # Create streamable HTTP connection with OAuth token in Authorization header # This matches the pattern from tests/conftest.py create_mcp_client_session() headers = {"Authorization": f"Bearer {profile.token}"} streamable_context = streamablehttp_client(mcp_url, headers=headers) try: read_stream, write_stream, _ = await streamable_context.__aenter__() session = ClientSession(read_stream, write_stream) await session.__aenter__() await session.initialize() # Store both session and context for proper cleanup profile.session = session profile.streamable_context = streamable_context logger.info(f"Created MCP session for {username}") return session except Exception as e: # Clean up streamable context if session creation failed try: await streamable_context.__aexit__(None, None, None) except Exception as cleanup_error: logger.debug(f"Error during cleanup: {cleanup_error}") raise e async def close_user_session(self, username: str): """Close the MCP session for a user.""" if username not in self.users: return profile = self.users[username] # Close ClientSession if profile.session: try: await profile.session.__aexit__(None, None, None) except Exception as e: logger.debug(f"Error closing session for {username}: {e}") profile.session = None # Close streamable context if profile.streamable_context: try: await profile.streamable_context.__aexit__(None, None, None) except Exception as e: logger.debug(f"Error closing streamable context for {username}: {e}") profile.streamable_context = None async def close_all_sessions(self): """Close all user sessions.""" for username in list(self.users.keys()): await self.close_user_session(username) def get_user(self, username: str) -> UserProfile: """Get user profile by username.""" if username not in self.users: raise KeyError(f"User {username} not in pool") return self.users[username] def get_all_users(self) -> list[UserProfile]: """Get all user profiles.""" return list(self.users.values()) def record_operation(self, username: str, success: bool = True): """Record an operation for user stats.""" if username in self.users: self.users[username].operation_count += 1 if not success: self.users[username].error_count += 1 def get_stats(self) -> dict[str, dict[str, int | float]]: """Get per-user operation statistics.""" return { username: { "operations": profile.operation_count, "errors": profile.error_count, "success_rate": ( (profile.operation_count - profile.error_count) / max(profile.operation_count, 1) * 100 ), } for username, profile in self.users.items() } async def create_nextcloud_user( self, username: str, password: str, display_name: str | None = None, email: str | None = None, ) -> UserConfig: """ Create a Nextcloud user via the Users API. Args: username: Username for the new user password: Password for the new user display_name: Optional display name email: Optional email address Returns: UserConfig for the created user Raises: HTTPStatusError: If user creation fails """ logger.info(f"Creating Nextcloud user: {username}") await self.admin_client.users.create_user( userid=username, password=password, display_name=display_name or username, email=email or f"{username}@benchmark.local", ) logger.info(f"Successfully created Nextcloud user: {username}") return UserConfig( username=username, password=password, display_name=display_name or username, email=email or f"{username}@benchmark.local", groups=[], ) async def delete_nextcloud_user(self, username: str): """ Delete a Nextcloud user via the Users API. Args: username: Username to delete """ logger.info(f"Deleting Nextcloud user: {username}") try: await self.admin_client.users.delete_user(userid=username) logger.info(f"Successfully deleted Nextcloud user: {username}") except Exception as e: logger.warning(f"Failed to delete user {username}: {e}") async def acquire_token_playwright( self, browser: Any, username: str, password: str, state: str, auth_states: dict[str, str], ) -> str: """ Acquire OAuth token via Playwright browser automation. Based on conftest.py playwright_oauth_token fixture. Automates the full OAuth flow: 1. Navigate to authorization URL 2. Fill login form 3. Handle OAuth consent 4. Wait for callback server to receive auth code 5. Exchange code for access token Args: browser: Playwright browser instance username: Username to authenticate password: Password for the user state: Unique state parameter for this OAuth flow auth_states: Dict mapping state -> auth_code (shared with callback server) Returns: OAuth access token Raises: TimeoutError: If callback not received within timeout ValueError: If token exchange fails """ logger.info(f"Starting Playwright OAuth flow for {username}...") logger.debug(f"Using state: {state[:16]}...") # Construct authorization URL auth_url = ( f"{self.authorization_endpoint}?" f"response_type=code&" f"client_id={self.client_id}&" f"redirect_uri={quote(self.callback_url, safe='')}&" f"state={state}&" f"scope=openid%20profile%20email" ) # Browser automation context = await browser.new_context(ignore_https_errors=True) page = await context.new_page() try: # Navigate to authorization URL logger.debug("Navigating to authorization URL...") await page.goto(auth_url, wait_until="networkidle", timeout=30000) current_url = page.url # Login if needed if "/login" in current_url or "/index.php/login" in current_url: logger.info(f"Logging in as {username}...") await page.wait_for_selector('input[name="user"]', timeout=10000) await page.fill('input[name="user"]', username) await page.fill('input[name="password"]', password) await page.click('button[type="submit"]') await page.wait_for_load_state("networkidle", timeout=30000) current_url = page.url logger.info("Login completed") # Handle OAuth consent if present try: authorize_button = await page.query_selector( 'button:has-text("Authorize"), button:has-text("Allow"), input[type="submit"][value*="uthoriz"]' ) if authorize_button: logger.info("Authorizing OAuth client...") await authorize_button.click() await page.wait_for_load_state("networkidle", timeout=10000) except Exception as e: logger.debug(f"No authorization needed: {e}") # Wait for callback server to receive auth code logger.info("Waiting for OAuth callback...") timeout_seconds = 30 start_time = time.time() while state not in auth_states: if time.time() - start_time > timeout_seconds: screenshot_path = f"/tmp/oauth_timeout_{username}.png" await page.screenshot(path=screenshot_path) logger.error(f"Screenshot saved to {screenshot_path}") raise TimeoutError( f"Timeout waiting for OAuth callback for {username}" ) await anyio.sleep(0.5) auth_code = auth_states[state] logger.info(f"Received auth code for {username}") finally: await context.close() # Exchange code for token logger.info(f"Exchanging auth code for access token ({username})...") token_response = await self._http_client.post( self.token_endpoint, data={ "grant_type": "authorization_code", "code": auth_code, "redirect_uri": self.callback_url, "client_id": self.client_id, "client_secret": self.client_secret, }, headers={"Content-Type": "application/x-www-form-urlencoded"}, ) token_response.raise_for_status() token_data = token_response.json() access_token = token_data.get("access_token") if not access_token: raise ValueError(f"No access token for {username}: {token_data}") logger.info(f"Successfully acquired OAuth token for {username}") return access_token class UserSessionWrapper: """ Wrapper for a user-specific MCP session with operation tracking. Provides a convenient interface for executing operations as a specific user. """ def __init__(self, username: str, session: ClientSession, pool: OAuthUserPool): self.username = username self.session = session self.pool = pool async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: """ Call an MCP tool and record the operation. Args: tool_name: Name of the tool to call arguments: Tool arguments Returns: Tool result """ try: result = await self.session.call_tool(tool_name, arguments) self.pool.record_operation(self.username, success=True) return result except Exception: self.pool.record_operation(self.username, success=False) raise async def read_resource(self, uri: str) -> Any: """ Read an MCP resource and record the operation. Args: uri: Resource URI Returns: Resource data """ try: result = await self.session.read_resource(uri) self.pool.record_operation(self.username, success=True) return result except Exception: self.pool.record_operation(self.username, success=False) raise def generate_secure_password(length: int = 20) -> str: """Generate a secure random password.""" alphabet = string.ascii_letters + string.digits + "!@#$%^&*()" return "".join(secrets.choice(alphabet) for _ in range(length))