test: Fix oauth tests by reusing callback server

This commit is contained in:
Chris Coutinho
2025-10-15 17:06:46 +02:00
parent 3ad9198f36
commit 46c6f2f294
+117 -84
View File
@@ -651,8 +651,10 @@ def oauth_callback_server():
"""
Fixture to create an HTTP server for OAuth callback handling.
Yields a tuple of (auth_state, server_url) where:
- auth_state: A dict with {"code": None} that will be populated with the auth code
Supports multiple concurrent OAuth flows using state parameters for correlation.
Yields a tuple of (auth_states, server_url) where:
- auth_states: A dict mapping state parameter to auth code
- server_url: The callback URL for the server (e.g., "http://localhost:8081")
The server automatically shuts down when the fixture is torn down.
@@ -663,8 +665,9 @@ def oauth_callback_server():
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import parse_qs, urlparse
# Use a mutable container to share state across threads
auth_state = {"code": None}
# Use a dict to store auth codes keyed by state parameter
# This allows multiple concurrent OAuth flows
auth_states = {}
httpd = None
server_thread = None
@@ -674,26 +677,27 @@ def oauth_callback_server():
pass
def do_GET(self):
# Ignore subsequent requests if we already have a code
# (this is a session-scoped fixture, so only process the first auth code)
if auth_state["code"] is not None:
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(
b"<html><body><h1>Authentication already completed</h1></body></html>"
)
return
# Parse the callback request
parsed_path = urlparse(self.path)
query = parse_qs(parsed_path.query)
code = query.get("code", [None])[0]
state = query.get("state", [None])[0]
# Only process if we have a valid code
if code:
auth_state["code"] = code
logger.info(f"OAuth callback received. Code: {code[:20]}...")
# Store code keyed by state parameter for correlation
if state:
auth_states[state] = code
logger.info(
f"OAuth callback received for state={state[:16]}... Code: {code[:20]}..."
)
else:
# Fallback for flows without state parameter (legacy interactive flow)
auth_states["_default"] = code
logger.info(
f"OAuth callback received (no state). Code: {code[:20]}..."
)
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
@@ -714,8 +718,8 @@ def oauth_callback_server():
server_thread.start()
logger.info("OAuth callback server started on http://localhost:8081")
# Yield the auth state and server URL
yield auth_state, "http://localhost:8081"
# Yield the auth states dict and server URL
yield auth_states, "http://localhost:8081"
finally:
# Clean up the server
@@ -746,8 +750,8 @@ async def interactive_oauth_token(oauth_callback_server) -> str:
from nextcloud_mcp_server.auth.client_registration import load_or_register_client
# Unpack the server fixture
auth_state, callback_url = oauth_callback_server
# Unpack the server fixture (now returns dict of auth_states)
auth_states, callback_url = oauth_callback_server
nextcloud_host = os.getenv("NEXTCLOUD_HOST")
async with httpx.AsyncClient() as http_client:
@@ -771,22 +775,22 @@ async def interactive_oauth_token(oauth_callback_server) -> str:
"After logging in, the OAuth authorization will proceed automatically"
)
# Construct authorization URL
# Construct authorization URL (no state parameter for interactive flow)
auth_url = f"{authorization_endpoint}?response_type=code&client_id={client_info.client_id}&redirect_uri={callback_url}&scope=openid%20profile%20email"
# Open authorization URL in browser
webbrowser.open(auth_url)
# Wait for auth code with timeout
# Wait for auth code with timeout (uses "_default" key for flows without state)
timeout = 120 # 2 minutes
start_time = time.time()
while not auth_state["code"]:
while "_default" not in auth_states:
if time.time() - start_time > timeout:
raise TimeoutError("OAuth authorization timed out after 2 minutes")
logger.info("Waiting for OAuth authorization...")
time.sleep(1)
auth_code = auth_state["code"]
auth_code = auth_states["_default"]
logger.info("Received authorization code, exchanging for token...")
token_response = await http_client.post(
@@ -809,13 +813,15 @@ async def interactive_oauth_token(oauth_callback_server) -> str:
@pytest.fixture(scope="session")
async def shared_oauth_client_credentials():
async def shared_oauth_client_credentials(oauth_callback_server):
"""
Fixture to obtain shared OAuth client credentials that will be reused for all users.
This registers a single OAuth client with Nextcloud that matches the MCP server's
registration, allowing all test users to authenticate using the same client_id/secret.
Now uses the real OAuth callback server for reliable token acquisition.
Returns:
Tuple of (client_id, client_secret, callback_url, token_endpoint, authorization_endpoint)
"""
@@ -825,7 +831,11 @@ async def shared_oauth_client_credentials():
if not nextcloud_host:
pytest.skip("Shared OAuth client requires NEXTCLOUD_HOST")
# Get callback URL from the real callback server
auth_states, callback_url = oauth_callback_server
logger.info("Setting up shared OAuth client credentials for all test users...")
logger.info(f"Using real callback server at: {callback_url}")
async with httpx.AsyncClient(timeout=30.0) as http_client:
# OIDC Discovery
@@ -841,9 +851,6 @@ async def shared_oauth_client_credentials():
if not all([token_endpoint, registration_endpoint, authorization_endpoint]):
raise ValueError("OIDC discovery missing required endpoints")
# Use callback URL that won't actually be used (we extract code from browser URL)
callback_url = "http://localhost:9999/oauth/callback"
# Register or load shared OAuth client (matches MCP server registration)
client_info = await load_or_register_client(
nextcloud_url=nextcloud_host,
@@ -866,7 +873,9 @@ async def shared_oauth_client_credentials():
@pytest.fixture(scope="session")
async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> str:
async def playwright_oauth_token(
browser, shared_oauth_client_credentials, oauth_callback_server
) -> str:
"""
Fixture to obtain an OAuth access token using Playwright headless browser automation.
@@ -875,7 +884,7 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st
2. Navigating to authorization URL in headless browser
3. Programmatically filling in login form
4. Handling OAuth consent
5. Extracting auth code from redirect
5. Waiting for callback server to receive auth code (NEW: using real callback server!)
6. Exchanging code for access token
Environment variables required:
@@ -888,7 +897,9 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st
- Browser fixture provided by pytest-playwright-asyncio
- See: https://playwright.dev/python/docs/test-runners
"""
from urllib.parse import parse_qs, urlparse
import secrets
import time
from urllib.parse import quote
nextcloud_host = os.getenv("NEXTCLOUD_HOST")
username = os.getenv("NEXTCLOUD_USERNAME")
@@ -899,6 +910,9 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st
"Playwright OAuth requires NEXTCLOUD_HOST, NEXTCLOUD_USERNAME, and NEXTCLOUD_PASSWORD"
)
# Get auth_states dict from callback server
auth_states, _ = oauth_callback_server
# Unpack shared client credentials
client_id, client_secret, callback_url, token_endpoint, authorization_endpoint = (
shared_oauth_client_credentials
@@ -906,13 +920,19 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st
logger.info(f"Starting Playwright-based OAuth flow for {username}...")
logger.info(f"Using shared OAuth client: {client_id[:16]}...")
logger.info(f"Using real callback server at: {callback_url}")
# Construct authorization URL
# Generate unique state parameter for this OAuth flow
state = secrets.token_urlsafe(32)
logger.debug(f"Generated state: {state[:16]}...")
# Construct authorization URL with state parameter
auth_url = (
f"{authorization_endpoint}?"
f"response_type=code&"
f"client_id={client_id}&"
f"redirect_uri={callback_url}&"
f"redirect_uri={quote(callback_url, safe='')}&"
f"state={state}&"
f"scope=openid%20profile%20email"
)
@@ -969,33 +989,24 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st
except Exception as e:
logger.debug(f"No authorization button found or already authorized: {e}")
# Wait for redirect to callback URL (which will fail to load, but we just need the URL)
try:
# The redirect might fail since localhost:9999 isn't actually running
# But we can still extract the code from the URL
await page.wait_for_url(f"{callback_url}*", timeout=10000)
except Exception as e:
# Expected - the callback URL won't load, but we should have the URL
logger.debug(f"Callback redirect (expected to fail): {e}")
# Wait for callback server to receive the auth code
# Browser will be redirected to localhost:8081 which will capture the code
logger.info("Waiting for callback server to receive auth code...")
timeout_seconds = 30
start_time = time.time()
while state not in auth_states:
if time.time() - start_time > timeout_seconds:
# Take a screenshot for debugging
screenshot_path = "/tmp/playwright_oauth_error.png"
await page.screenshot(path=screenshot_path)
logger.error(f"Screenshot saved to {screenshot_path}")
raise TimeoutError(
f"Timeout waiting for OAuth callback (state={state[:16]}...)"
)
await asyncio.sleep(0.5)
# Extract auth code from URL
final_url = page.url
logger.debug(f"Final URL: {final_url}")
parsed_url = urlparse(final_url)
query_params = parse_qs(parsed_url.query)
auth_code = query_params.get("code", [None])[0]
if not auth_code:
# Take a screenshot for debugging
screenshot_path = "/tmp/playwright_oauth_error.png"
await page.screenshot(path=screenshot_path)
logger.error(f"Screenshot saved to {screenshot_path}")
raise ValueError(
f"No authorization code found in redirect URL: {final_url}"
)
logger.info(f"Successfully extracted authorization code: {auth_code[:20]}...")
auth_code = auth_states[state]
logger.info(f"Successfully received authorization code: {auth_code[:20]}...")
finally:
await context.close()
@@ -1234,23 +1245,31 @@ async def test_users_setup(nc_client: NextcloudClient):
async def _get_oauth_token_for_user(
browser, shared_oauth_client_credentials, username: str, password: str
browser,
shared_oauth_client_credentials,
auth_states,
username: str,
password: str,
) -> str:
"""
Helper function to get OAuth access token for a user via Playwright.
Uses shared OAuth client credentials to authenticate multiple users with the same client.
Now uses real callback server with state parameters for reliable token acquisition.
Args:
browser: Playwright browser instance
shared_oauth_client_credentials: Tuple of (client_id, client_secret, callback_url, token_endpoint, authorization_endpoint)
auth_states: Dict mapping state parameters to auth codes (from callback server)
username: Username to authenticate as
password: Password for the user
Returns:
OAuth access token string
"""
from urllib.parse import parse_qs, urlparse
import secrets
import time
from urllib.parse import quote
nextcloud_host = os.getenv("NEXTCLOUD_HOST")
@@ -1265,14 +1284,17 @@ async def _get_oauth_token_for_user(
logger.info(f"Getting OAuth token for user: {username}...")
logger.info(f"Using shared OAuth client: {client_id[:16]}...")
# Construct authorization URL with properly encoded redirect_uri
from urllib.parse import quote
# Generate unique state parameter for this OAuth flow
state = secrets.token_urlsafe(32)
logger.debug(f"Generated state for {username}: {state[:16]}...")
# Construct authorization URL with state parameter
auth_url = (
f"{authorization_endpoint}?"
f"response_type=code&"
f"client_id={client_id}&"
f"redirect_uri={quote(callback_url, safe='')}&"
f"state={state}&"
f"scope=openid%20profile%20email"
)
@@ -1309,22 +1331,25 @@ async def _get_oauth_token_for_user(
except Exception as e:
logger.debug(f"No authorization needed for {username}: {e}")
# Wait for redirect and extract auth code
try:
await page.wait_for_url(f"{callback_url}*", timeout=30000)
except Exception:
pass # Expected - callback won't load
final_url = page.url
parsed_url = urlparse(final_url)
query_params = parse_qs(parsed_url.query)
auth_code = query_params.get("code", [None])[0]
if not auth_code:
raise ValueError(
f"No authorization code found for {username} in URL: {final_url}"
)
# Wait for callback server to receive the auth code
# Browser will be redirected to localhost:8081 which will capture the code
logger.info(
f"Waiting for callback server to receive auth code for {username}..."
)
timeout_seconds = 30
start_time = time.time()
while state not in auth_states:
if time.time() - start_time > timeout_seconds:
# Take screenshot for debugging
screenshot_path = f"/tmp/playwright_oauth_timeout_{username}.png"
await page.screenshot(path=screenshot_path)
logger.error(f"Screenshot saved to {screenshot_path}")
raise TimeoutError(
f"Timeout waiting for OAuth callback for {username} (state={state[:16]}...)"
)
await asyncio.sleep(0.5)
auth_code = auth_states[state]
logger.info(f"Got auth code for {username}: {auth_code[:20]}...")
finally:
@@ -1358,7 +1383,7 @@ async def _get_oauth_token_for_user(
# Parallel token retrieval fixture - fetches all OAuth tokens concurrently
@pytest.fixture(scope="session")
async def all_oauth_tokens(
browser, shared_oauth_client_credentials, test_users_setup
browser, shared_oauth_client_credentials, test_users_setup, oauth_callback_server
) -> dict[str, str]:
"""
Fetch OAuth tokens for all test users in parallel for speed.
@@ -1366,26 +1391,34 @@ async def all_oauth_tokens(
Returns a dict mapping username to OAuth access token.
This is significantly faster than fetching tokens sequentially.
Note: We add a small stagger between starting each flow to avoid
race conditions in Nextcloud's OAuth session handling.
Now uses the real callback server with state parameters for reliable
concurrent token acquisition without race conditions.
"""
import asyncio
import time
# Get auth_states dict from callback server
auth_states, callback_url = oauth_callback_server
start_time = time.time()
logger.info("Fetching OAuth tokens for all users in parallel...")
logger.info(f"Using callback server at {callback_url} with state-based correlation")
async def get_token_with_delay(username: str, config: dict, delay: float):
"""Get token for a user after a small delay to stagger requests."""
if delay > 0:
await asyncio.sleep(delay)
return await _get_oauth_token_for_user(
browser, shared_oauth_client_credentials, username, config["password"]
browser,
shared_oauth_client_credentials,
auth_states,
username,
config["password"],
)
# Create tasks for all users with staggered starts (2.0s apart)
tasks = {
username: get_token_with_delay(username, config, (idx + 1) * 2.0)
username: get_token_with_delay(username, config, idx * 0.5)
for idx, (username, config) in enumerate(test_users_setup.items())
}