diff --git a/tests/conftest.py b/tests/conftest.py index 852f85d..90a0f84e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -651,8 +651,10 @@ def oauth_callback_server(): """ Fixture to create an HTTP server for OAuth callback handling. - Yields a tuple of (auth_state, server_url) where: - - auth_state: A dict with {"code": None} that will be populated with the auth code + Supports multiple concurrent OAuth flows using state parameters for correlation. + + Yields a tuple of (auth_states, server_url) where: + - auth_states: A dict mapping state parameter to auth code - server_url: The callback URL for the server (e.g., "http://localhost:8081") The server automatically shuts down when the fixture is torn down. @@ -663,8 +665,9 @@ def oauth_callback_server(): from http.server import BaseHTTPRequestHandler, HTTPServer from urllib.parse import parse_qs, urlparse - # Use a mutable container to share state across threads - auth_state = {"code": None} + # Use a dict to store auth codes keyed by state parameter + # This allows multiple concurrent OAuth flows + auth_states = {} httpd = None server_thread = None @@ -674,26 +677,27 @@ def oauth_callback_server(): pass def do_GET(self): - # Ignore subsequent requests if we already have a code - # (this is a session-scoped fixture, so only process the first auth code) - if auth_state["code"] is not None: - self.send_response(200) - self.send_header("Content-type", "text/html") - self.end_headers() - self.wfile.write( - b"

Authentication already completed

" - ) - return - # Parse the callback request parsed_path = urlparse(self.path) query = parse_qs(parsed_path.query) code = query.get("code", [None])[0] + state = query.get("state", [None])[0] # Only process if we have a valid code if code: - auth_state["code"] = code - logger.info(f"OAuth callback received. Code: {code[:20]}...") + # Store code keyed by state parameter for correlation + if state: + auth_states[state] = code + logger.info( + f"OAuth callback received for state={state[:16]}... Code: {code[:20]}..." + ) + else: + # Fallback for flows without state parameter (legacy interactive flow) + auth_states["_default"] = code + logger.info( + f"OAuth callback received (no state). Code: {code[:20]}..." + ) + self.send_response(200) self.send_header("Content-type", "text/html") self.end_headers() @@ -714,8 +718,8 @@ def oauth_callback_server(): server_thread.start() logger.info("OAuth callback server started on http://localhost:8081") - # Yield the auth state and server URL - yield auth_state, "http://localhost:8081" + # Yield the auth states dict and server URL + yield auth_states, "http://localhost:8081" finally: # Clean up the server @@ -746,8 +750,8 @@ async def interactive_oauth_token(oauth_callback_server) -> str: from nextcloud_mcp_server.auth.client_registration import load_or_register_client - # Unpack the server fixture - auth_state, callback_url = oauth_callback_server + # Unpack the server fixture (now returns dict of auth_states) + auth_states, callback_url = oauth_callback_server nextcloud_host = os.getenv("NEXTCLOUD_HOST") async with httpx.AsyncClient() as http_client: @@ -771,22 +775,22 @@ async def interactive_oauth_token(oauth_callback_server) -> str: "After logging in, the OAuth authorization will proceed automatically" ) - # Construct authorization URL + # Construct authorization URL (no state parameter for interactive flow) auth_url = f"{authorization_endpoint}?response_type=code&client_id={client_info.client_id}&redirect_uri={callback_url}&scope=openid%20profile%20email" # Open authorization URL in browser webbrowser.open(auth_url) - # Wait for auth code with timeout + # Wait for auth code with timeout (uses "_default" key for flows without state) timeout = 120 # 2 minutes start_time = time.time() - while not auth_state["code"]: + while "_default" not in auth_states: if time.time() - start_time > timeout: raise TimeoutError("OAuth authorization timed out after 2 minutes") logger.info("Waiting for OAuth authorization...") time.sleep(1) - auth_code = auth_state["code"] + auth_code = auth_states["_default"] logger.info("Received authorization code, exchanging for token...") token_response = await http_client.post( @@ -809,13 +813,15 @@ async def interactive_oauth_token(oauth_callback_server) -> str: @pytest.fixture(scope="session") -async def shared_oauth_client_credentials(): +async def shared_oauth_client_credentials(oauth_callback_server): """ Fixture to obtain shared OAuth client credentials that will be reused for all users. This registers a single OAuth client with Nextcloud that matches the MCP server's registration, allowing all test users to authenticate using the same client_id/secret. + Now uses the real OAuth callback server for reliable token acquisition. + Returns: Tuple of (client_id, client_secret, callback_url, token_endpoint, authorization_endpoint) """ @@ -825,7 +831,11 @@ async def shared_oauth_client_credentials(): if not nextcloud_host: pytest.skip("Shared OAuth client requires NEXTCLOUD_HOST") + # Get callback URL from the real callback server + auth_states, callback_url = oauth_callback_server + logger.info("Setting up shared OAuth client credentials for all test users...") + logger.info(f"Using real callback server at: {callback_url}") async with httpx.AsyncClient(timeout=30.0) as http_client: # OIDC Discovery @@ -841,9 +851,6 @@ async def shared_oauth_client_credentials(): if not all([token_endpoint, registration_endpoint, authorization_endpoint]): raise ValueError("OIDC discovery missing required endpoints") - # Use callback URL that won't actually be used (we extract code from browser URL) - callback_url = "http://localhost:9999/oauth/callback" - # Register or load shared OAuth client (matches MCP server registration) client_info = await load_or_register_client( nextcloud_url=nextcloud_host, @@ -866,7 +873,9 @@ async def shared_oauth_client_credentials(): @pytest.fixture(scope="session") -async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> str: +async def playwright_oauth_token( + browser, shared_oauth_client_credentials, oauth_callback_server +) -> str: """ Fixture to obtain an OAuth access token using Playwright headless browser automation. @@ -875,7 +884,7 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st 2. Navigating to authorization URL in headless browser 3. Programmatically filling in login form 4. Handling OAuth consent - 5. Extracting auth code from redirect + 5. Waiting for callback server to receive auth code (NEW: using real callback server!) 6. Exchanging code for access token Environment variables required: @@ -888,7 +897,9 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st - Browser fixture provided by pytest-playwright-asyncio - See: https://playwright.dev/python/docs/test-runners """ - from urllib.parse import parse_qs, urlparse + import secrets + import time + from urllib.parse import quote nextcloud_host = os.getenv("NEXTCLOUD_HOST") username = os.getenv("NEXTCLOUD_USERNAME") @@ -899,6 +910,9 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st "Playwright OAuth requires NEXTCLOUD_HOST, NEXTCLOUD_USERNAME, and NEXTCLOUD_PASSWORD" ) + # Get auth_states dict from callback server + auth_states, _ = oauth_callback_server + # Unpack shared client credentials client_id, client_secret, callback_url, token_endpoint, authorization_endpoint = ( shared_oauth_client_credentials @@ -906,13 +920,19 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st logger.info(f"Starting Playwright-based OAuth flow for {username}...") logger.info(f"Using shared OAuth client: {client_id[:16]}...") + logger.info(f"Using real callback server at: {callback_url}") - # Construct authorization URL + # Generate unique state parameter for this OAuth flow + state = secrets.token_urlsafe(32) + logger.debug(f"Generated state: {state[:16]}...") + + # Construct authorization URL with state parameter auth_url = ( f"{authorization_endpoint}?" f"response_type=code&" f"client_id={client_id}&" - f"redirect_uri={callback_url}&" + f"redirect_uri={quote(callback_url, safe='')}&" + f"state={state}&" f"scope=openid%20profile%20email" ) @@ -969,33 +989,24 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st except Exception as e: logger.debug(f"No authorization button found or already authorized: {e}") - # Wait for redirect to callback URL (which will fail to load, but we just need the URL) - try: - # The redirect might fail since localhost:9999 isn't actually running - # But we can still extract the code from the URL - await page.wait_for_url(f"{callback_url}*", timeout=10000) - except Exception as e: - # Expected - the callback URL won't load, but we should have the URL - logger.debug(f"Callback redirect (expected to fail): {e}") + # Wait for callback server to receive the auth code + # Browser will be redirected to localhost:8081 which will capture the code + logger.info("Waiting for callback server to receive auth code...") + timeout_seconds = 30 + start_time = time.time() + while state not in auth_states: + if time.time() - start_time > timeout_seconds: + # Take a screenshot for debugging + screenshot_path = "/tmp/playwright_oauth_error.png" + await page.screenshot(path=screenshot_path) + logger.error(f"Screenshot saved to {screenshot_path}") + raise TimeoutError( + f"Timeout waiting for OAuth callback (state={state[:16]}...)" + ) + await asyncio.sleep(0.5) - # Extract auth code from URL - final_url = page.url - logger.debug(f"Final URL: {final_url}") - - parsed_url = urlparse(final_url) - query_params = parse_qs(parsed_url.query) - auth_code = query_params.get("code", [None])[0] - - if not auth_code: - # Take a screenshot for debugging - screenshot_path = "/tmp/playwright_oauth_error.png" - await page.screenshot(path=screenshot_path) - logger.error(f"Screenshot saved to {screenshot_path}") - raise ValueError( - f"No authorization code found in redirect URL: {final_url}" - ) - - logger.info(f"Successfully extracted authorization code: {auth_code[:20]}...") + auth_code = auth_states[state] + logger.info(f"Successfully received authorization code: {auth_code[:20]}...") finally: await context.close() @@ -1234,23 +1245,31 @@ async def test_users_setup(nc_client: NextcloudClient): async def _get_oauth_token_for_user( - browser, shared_oauth_client_credentials, username: str, password: str + browser, + shared_oauth_client_credentials, + auth_states, + username: str, + password: str, ) -> str: """ Helper function to get OAuth access token for a user via Playwright. Uses shared OAuth client credentials to authenticate multiple users with the same client. + Now uses real callback server with state parameters for reliable token acquisition. Args: browser: Playwright browser instance shared_oauth_client_credentials: Tuple of (client_id, client_secret, callback_url, token_endpoint, authorization_endpoint) + auth_states: Dict mapping state parameters to auth codes (from callback server) username: Username to authenticate as password: Password for the user Returns: OAuth access token string """ - from urllib.parse import parse_qs, urlparse + import secrets + import time + from urllib.parse import quote nextcloud_host = os.getenv("NEXTCLOUD_HOST") @@ -1265,14 +1284,17 @@ async def _get_oauth_token_for_user( logger.info(f"Getting OAuth token for user: {username}...") logger.info(f"Using shared OAuth client: {client_id[:16]}...") - # Construct authorization URL with properly encoded redirect_uri - from urllib.parse import quote + # Generate unique state parameter for this OAuth flow + state = secrets.token_urlsafe(32) + logger.debug(f"Generated state for {username}: {state[:16]}...") + # Construct authorization URL with state parameter auth_url = ( f"{authorization_endpoint}?" f"response_type=code&" f"client_id={client_id}&" f"redirect_uri={quote(callback_url, safe='')}&" + f"state={state}&" f"scope=openid%20profile%20email" ) @@ -1309,22 +1331,25 @@ async def _get_oauth_token_for_user( except Exception as e: logger.debug(f"No authorization needed for {username}: {e}") - # Wait for redirect and extract auth code - try: - await page.wait_for_url(f"{callback_url}*", timeout=30000) - except Exception: - pass # Expected - callback won't load - - final_url = page.url - parsed_url = urlparse(final_url) - query_params = parse_qs(parsed_url.query) - auth_code = query_params.get("code", [None])[0] - - if not auth_code: - raise ValueError( - f"No authorization code found for {username} in URL: {final_url}" - ) + # Wait for callback server to receive the auth code + # Browser will be redirected to localhost:8081 which will capture the code + logger.info( + f"Waiting for callback server to receive auth code for {username}..." + ) + timeout_seconds = 30 + start_time = time.time() + while state not in auth_states: + if time.time() - start_time > timeout_seconds: + # Take screenshot for debugging + screenshot_path = f"/tmp/playwright_oauth_timeout_{username}.png" + await page.screenshot(path=screenshot_path) + logger.error(f"Screenshot saved to {screenshot_path}") + raise TimeoutError( + f"Timeout waiting for OAuth callback for {username} (state={state[:16]}...)" + ) + await asyncio.sleep(0.5) + auth_code = auth_states[state] logger.info(f"Got auth code for {username}: {auth_code[:20]}...") finally: @@ -1358,7 +1383,7 @@ async def _get_oauth_token_for_user( # Parallel token retrieval fixture - fetches all OAuth tokens concurrently @pytest.fixture(scope="session") async def all_oauth_tokens( - browser, shared_oauth_client_credentials, test_users_setup + browser, shared_oauth_client_credentials, test_users_setup, oauth_callback_server ) -> dict[str, str]: """ Fetch OAuth tokens for all test users in parallel for speed. @@ -1366,26 +1391,34 @@ async def all_oauth_tokens( Returns a dict mapping username to OAuth access token. This is significantly faster than fetching tokens sequentially. - Note: We add a small stagger between starting each flow to avoid - race conditions in Nextcloud's OAuth session handling. + Now uses the real callback server with state parameters for reliable + concurrent token acquisition without race conditions. """ import asyncio import time + # Get auth_states dict from callback server + auth_states, callback_url = oauth_callback_server + start_time = time.time() logger.info("Fetching OAuth tokens for all users in parallel...") + logger.info(f"Using callback server at {callback_url} with state-based correlation") async def get_token_with_delay(username: str, config: dict, delay: float): """Get token for a user after a small delay to stagger requests.""" if delay > 0: await asyncio.sleep(delay) return await _get_oauth_token_for_user( - browser, shared_oauth_client_credentials, username, config["password"] + browser, + shared_oauth_client_credentials, + auth_states, + username, + config["password"], ) # Create tasks for all users with staggered starts (2.0s apart) tasks = { - username: get_token_with_delay(username, config, (idx + 1) * 2.0) + username: get_token_with_delay(username, config, idx * 0.5) for idx, (username, config) in enumerate(test_users_setup.items()) }