test: Fix oauth tests by reusing callback server
This commit is contained in:
+117
-84
@@ -651,8 +651,10 @@ def oauth_callback_server():
|
||||
"""
|
||||
Fixture to create an HTTP server for OAuth callback handling.
|
||||
|
||||
Yields a tuple of (auth_state, server_url) where:
|
||||
- auth_state: A dict with {"code": None} that will be populated with the auth code
|
||||
Supports multiple concurrent OAuth flows using state parameters for correlation.
|
||||
|
||||
Yields a tuple of (auth_states, server_url) where:
|
||||
- auth_states: A dict mapping state parameter to auth code
|
||||
- server_url: The callback URL for the server (e.g., "http://localhost:8081")
|
||||
|
||||
The server automatically shuts down when the fixture is torn down.
|
||||
@@ -663,8 +665,9 @@ def oauth_callback_server():
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
# Use a mutable container to share state across threads
|
||||
auth_state = {"code": None}
|
||||
# Use a dict to store auth codes keyed by state parameter
|
||||
# This allows multiple concurrent OAuth flows
|
||||
auth_states = {}
|
||||
httpd = None
|
||||
server_thread = None
|
||||
|
||||
@@ -674,26 +677,27 @@ def oauth_callback_server():
|
||||
pass
|
||||
|
||||
def do_GET(self):
|
||||
# Ignore subsequent requests if we already have a code
|
||||
# (this is a session-scoped fixture, so only process the first auth code)
|
||||
if auth_state["code"] is not None:
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(
|
||||
b"<html><body><h1>Authentication already completed</h1></body></html>"
|
||||
)
|
||||
return
|
||||
|
||||
# Parse the callback request
|
||||
parsed_path = urlparse(self.path)
|
||||
query = parse_qs(parsed_path.query)
|
||||
code = query.get("code", [None])[0]
|
||||
state = query.get("state", [None])[0]
|
||||
|
||||
# Only process if we have a valid code
|
||||
if code:
|
||||
auth_state["code"] = code
|
||||
logger.info(f"OAuth callback received. Code: {code[:20]}...")
|
||||
# Store code keyed by state parameter for correlation
|
||||
if state:
|
||||
auth_states[state] = code
|
||||
logger.info(
|
||||
f"OAuth callback received for state={state[:16]}... Code: {code[:20]}..."
|
||||
)
|
||||
else:
|
||||
# Fallback for flows without state parameter (legacy interactive flow)
|
||||
auth_states["_default"] = code
|
||||
logger.info(
|
||||
f"OAuth callback received (no state). Code: {code[:20]}..."
|
||||
)
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
@@ -714,8 +718,8 @@ def oauth_callback_server():
|
||||
server_thread.start()
|
||||
logger.info("OAuth callback server started on http://localhost:8081")
|
||||
|
||||
# Yield the auth state and server URL
|
||||
yield auth_state, "http://localhost:8081"
|
||||
# Yield the auth states dict and server URL
|
||||
yield auth_states, "http://localhost:8081"
|
||||
|
||||
finally:
|
||||
# Clean up the server
|
||||
@@ -746,8 +750,8 @@ async def interactive_oauth_token(oauth_callback_server) -> str:
|
||||
|
||||
from nextcloud_mcp_server.auth.client_registration import load_or_register_client
|
||||
|
||||
# Unpack the server fixture
|
||||
auth_state, callback_url = oauth_callback_server
|
||||
# Unpack the server fixture (now returns dict of auth_states)
|
||||
auth_states, callback_url = oauth_callback_server
|
||||
|
||||
nextcloud_host = os.getenv("NEXTCLOUD_HOST")
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
@@ -771,22 +775,22 @@ async def interactive_oauth_token(oauth_callback_server) -> str:
|
||||
"After logging in, the OAuth authorization will proceed automatically"
|
||||
)
|
||||
|
||||
# Construct authorization URL
|
||||
# Construct authorization URL (no state parameter for interactive flow)
|
||||
auth_url = f"{authorization_endpoint}?response_type=code&client_id={client_info.client_id}&redirect_uri={callback_url}&scope=openid%20profile%20email"
|
||||
|
||||
# Open authorization URL in browser
|
||||
webbrowser.open(auth_url)
|
||||
|
||||
# Wait for auth code with timeout
|
||||
# Wait for auth code with timeout (uses "_default" key for flows without state)
|
||||
timeout = 120 # 2 minutes
|
||||
start_time = time.time()
|
||||
while not auth_state["code"]:
|
||||
while "_default" not in auth_states:
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError("OAuth authorization timed out after 2 minutes")
|
||||
logger.info("Waiting for OAuth authorization...")
|
||||
time.sleep(1)
|
||||
|
||||
auth_code = auth_state["code"]
|
||||
auth_code = auth_states["_default"]
|
||||
logger.info("Received authorization code, exchanging for token...")
|
||||
|
||||
token_response = await http_client.post(
|
||||
@@ -809,13 +813,15 @@ async def interactive_oauth_token(oauth_callback_server) -> str:
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def shared_oauth_client_credentials():
|
||||
async def shared_oauth_client_credentials(oauth_callback_server):
|
||||
"""
|
||||
Fixture to obtain shared OAuth client credentials that will be reused for all users.
|
||||
|
||||
This registers a single OAuth client with Nextcloud that matches the MCP server's
|
||||
registration, allowing all test users to authenticate using the same client_id/secret.
|
||||
|
||||
Now uses the real OAuth callback server for reliable token acquisition.
|
||||
|
||||
Returns:
|
||||
Tuple of (client_id, client_secret, callback_url, token_endpoint, authorization_endpoint)
|
||||
"""
|
||||
@@ -825,7 +831,11 @@ async def shared_oauth_client_credentials():
|
||||
if not nextcloud_host:
|
||||
pytest.skip("Shared OAuth client requires NEXTCLOUD_HOST")
|
||||
|
||||
# Get callback URL from the real callback server
|
||||
auth_states, callback_url = oauth_callback_server
|
||||
|
||||
logger.info("Setting up shared OAuth client credentials for all test users...")
|
||||
logger.info(f"Using real callback server at: {callback_url}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as http_client:
|
||||
# OIDC Discovery
|
||||
@@ -841,9 +851,6 @@ async def shared_oauth_client_credentials():
|
||||
if not all([token_endpoint, registration_endpoint, authorization_endpoint]):
|
||||
raise ValueError("OIDC discovery missing required endpoints")
|
||||
|
||||
# Use callback URL that won't actually be used (we extract code from browser URL)
|
||||
callback_url = "http://localhost:9999/oauth/callback"
|
||||
|
||||
# Register or load shared OAuth client (matches MCP server registration)
|
||||
client_info = await load_or_register_client(
|
||||
nextcloud_url=nextcloud_host,
|
||||
@@ -866,7 +873,9 @@ async def shared_oauth_client_credentials():
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> str:
|
||||
async def playwright_oauth_token(
|
||||
browser, shared_oauth_client_credentials, oauth_callback_server
|
||||
) -> str:
|
||||
"""
|
||||
Fixture to obtain an OAuth access token using Playwright headless browser automation.
|
||||
|
||||
@@ -875,7 +884,7 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st
|
||||
2. Navigating to authorization URL in headless browser
|
||||
3. Programmatically filling in login form
|
||||
4. Handling OAuth consent
|
||||
5. Extracting auth code from redirect
|
||||
5. Waiting for callback server to receive auth code (NEW: using real callback server!)
|
||||
6. Exchanging code for access token
|
||||
|
||||
Environment variables required:
|
||||
@@ -888,7 +897,9 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st
|
||||
- Browser fixture provided by pytest-playwright-asyncio
|
||||
- See: https://playwright.dev/python/docs/test-runners
|
||||
"""
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
import secrets
|
||||
import time
|
||||
from urllib.parse import quote
|
||||
|
||||
nextcloud_host = os.getenv("NEXTCLOUD_HOST")
|
||||
username = os.getenv("NEXTCLOUD_USERNAME")
|
||||
@@ -899,6 +910,9 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st
|
||||
"Playwright OAuth requires NEXTCLOUD_HOST, NEXTCLOUD_USERNAME, and NEXTCLOUD_PASSWORD"
|
||||
)
|
||||
|
||||
# Get auth_states dict from callback server
|
||||
auth_states, _ = oauth_callback_server
|
||||
|
||||
# Unpack shared client credentials
|
||||
client_id, client_secret, callback_url, token_endpoint, authorization_endpoint = (
|
||||
shared_oauth_client_credentials
|
||||
@@ -906,13 +920,19 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st
|
||||
|
||||
logger.info(f"Starting Playwright-based OAuth flow for {username}...")
|
||||
logger.info(f"Using shared OAuth client: {client_id[:16]}...")
|
||||
logger.info(f"Using real callback server at: {callback_url}")
|
||||
|
||||
# Construct authorization URL
|
||||
# Generate unique state parameter for this OAuth flow
|
||||
state = secrets.token_urlsafe(32)
|
||||
logger.debug(f"Generated state: {state[:16]}...")
|
||||
|
||||
# Construct authorization URL with state parameter
|
||||
auth_url = (
|
||||
f"{authorization_endpoint}?"
|
||||
f"response_type=code&"
|
||||
f"client_id={client_id}&"
|
||||
f"redirect_uri={callback_url}&"
|
||||
f"redirect_uri={quote(callback_url, safe='')}&"
|
||||
f"state={state}&"
|
||||
f"scope=openid%20profile%20email"
|
||||
)
|
||||
|
||||
@@ -969,33 +989,24 @@ async def playwright_oauth_token(browser, shared_oauth_client_credentials) -> st
|
||||
except Exception as e:
|
||||
logger.debug(f"No authorization button found or already authorized: {e}")
|
||||
|
||||
# Wait for redirect to callback URL (which will fail to load, but we just need the URL)
|
||||
try:
|
||||
# The redirect might fail since localhost:9999 isn't actually running
|
||||
# But we can still extract the code from the URL
|
||||
await page.wait_for_url(f"{callback_url}*", timeout=10000)
|
||||
except Exception as e:
|
||||
# Expected - the callback URL won't load, but we should have the URL
|
||||
logger.debug(f"Callback redirect (expected to fail): {e}")
|
||||
# Wait for callback server to receive the auth code
|
||||
# Browser will be redirected to localhost:8081 which will capture the code
|
||||
logger.info("Waiting for callback server to receive auth code...")
|
||||
timeout_seconds = 30
|
||||
start_time = time.time()
|
||||
while state not in auth_states:
|
||||
if time.time() - start_time > timeout_seconds:
|
||||
# Take a screenshot for debugging
|
||||
screenshot_path = "/tmp/playwright_oauth_error.png"
|
||||
await page.screenshot(path=screenshot_path)
|
||||
logger.error(f"Screenshot saved to {screenshot_path}")
|
||||
raise TimeoutError(
|
||||
f"Timeout waiting for OAuth callback (state={state[:16]}...)"
|
||||
)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Extract auth code from URL
|
||||
final_url = page.url
|
||||
logger.debug(f"Final URL: {final_url}")
|
||||
|
||||
parsed_url = urlparse(final_url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
auth_code = query_params.get("code", [None])[0]
|
||||
|
||||
if not auth_code:
|
||||
# Take a screenshot for debugging
|
||||
screenshot_path = "/tmp/playwright_oauth_error.png"
|
||||
await page.screenshot(path=screenshot_path)
|
||||
logger.error(f"Screenshot saved to {screenshot_path}")
|
||||
raise ValueError(
|
||||
f"No authorization code found in redirect URL: {final_url}"
|
||||
)
|
||||
|
||||
logger.info(f"Successfully extracted authorization code: {auth_code[:20]}...")
|
||||
auth_code = auth_states[state]
|
||||
logger.info(f"Successfully received authorization code: {auth_code[:20]}...")
|
||||
|
||||
finally:
|
||||
await context.close()
|
||||
@@ -1234,23 +1245,31 @@ async def test_users_setup(nc_client: NextcloudClient):
|
||||
|
||||
|
||||
async def _get_oauth_token_for_user(
|
||||
browser, shared_oauth_client_credentials, username: str, password: str
|
||||
browser,
|
||||
shared_oauth_client_credentials,
|
||||
auth_states,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> str:
|
||||
"""
|
||||
Helper function to get OAuth access token for a user via Playwright.
|
||||
|
||||
Uses shared OAuth client credentials to authenticate multiple users with the same client.
|
||||
Now uses real callback server with state parameters for reliable token acquisition.
|
||||
|
||||
Args:
|
||||
browser: Playwright browser instance
|
||||
shared_oauth_client_credentials: Tuple of (client_id, client_secret, callback_url, token_endpoint, authorization_endpoint)
|
||||
auth_states: Dict mapping state parameters to auth codes (from callback server)
|
||||
username: Username to authenticate as
|
||||
password: Password for the user
|
||||
|
||||
Returns:
|
||||
OAuth access token string
|
||||
"""
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
import secrets
|
||||
import time
|
||||
from urllib.parse import quote
|
||||
|
||||
nextcloud_host = os.getenv("NEXTCLOUD_HOST")
|
||||
|
||||
@@ -1265,14 +1284,17 @@ async def _get_oauth_token_for_user(
|
||||
logger.info(f"Getting OAuth token for user: {username}...")
|
||||
logger.info(f"Using shared OAuth client: {client_id[:16]}...")
|
||||
|
||||
# Construct authorization URL with properly encoded redirect_uri
|
||||
from urllib.parse import quote
|
||||
# Generate unique state parameter for this OAuth flow
|
||||
state = secrets.token_urlsafe(32)
|
||||
logger.debug(f"Generated state for {username}: {state[:16]}...")
|
||||
|
||||
# Construct authorization URL with state parameter
|
||||
auth_url = (
|
||||
f"{authorization_endpoint}?"
|
||||
f"response_type=code&"
|
||||
f"client_id={client_id}&"
|
||||
f"redirect_uri={quote(callback_url, safe='')}&"
|
||||
f"state={state}&"
|
||||
f"scope=openid%20profile%20email"
|
||||
)
|
||||
|
||||
@@ -1309,22 +1331,25 @@ async def _get_oauth_token_for_user(
|
||||
except Exception as e:
|
||||
logger.debug(f"No authorization needed for {username}: {e}")
|
||||
|
||||
# Wait for redirect and extract auth code
|
||||
try:
|
||||
await page.wait_for_url(f"{callback_url}*", timeout=30000)
|
||||
except Exception:
|
||||
pass # Expected - callback won't load
|
||||
|
||||
final_url = page.url
|
||||
parsed_url = urlparse(final_url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
auth_code = query_params.get("code", [None])[0]
|
||||
|
||||
if not auth_code:
|
||||
raise ValueError(
|
||||
f"No authorization code found for {username} in URL: {final_url}"
|
||||
)
|
||||
# Wait for callback server to receive the auth code
|
||||
# Browser will be redirected to localhost:8081 which will capture the code
|
||||
logger.info(
|
||||
f"Waiting for callback server to receive auth code for {username}..."
|
||||
)
|
||||
timeout_seconds = 30
|
||||
start_time = time.time()
|
||||
while state not in auth_states:
|
||||
if time.time() - start_time > timeout_seconds:
|
||||
# Take screenshot for debugging
|
||||
screenshot_path = f"/tmp/playwright_oauth_timeout_{username}.png"
|
||||
await page.screenshot(path=screenshot_path)
|
||||
logger.error(f"Screenshot saved to {screenshot_path}")
|
||||
raise TimeoutError(
|
||||
f"Timeout waiting for OAuth callback for {username} (state={state[:16]}...)"
|
||||
)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
auth_code = auth_states[state]
|
||||
logger.info(f"Got auth code for {username}: {auth_code[:20]}...")
|
||||
|
||||
finally:
|
||||
@@ -1358,7 +1383,7 @@ async def _get_oauth_token_for_user(
|
||||
# Parallel token retrieval fixture - fetches all OAuth tokens concurrently
|
||||
@pytest.fixture(scope="session")
|
||||
async def all_oauth_tokens(
|
||||
browser, shared_oauth_client_credentials, test_users_setup
|
||||
browser, shared_oauth_client_credentials, test_users_setup, oauth_callback_server
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Fetch OAuth tokens for all test users in parallel for speed.
|
||||
@@ -1366,26 +1391,34 @@ async def all_oauth_tokens(
|
||||
Returns a dict mapping username to OAuth access token.
|
||||
This is significantly faster than fetching tokens sequentially.
|
||||
|
||||
Note: We add a small stagger between starting each flow to avoid
|
||||
race conditions in Nextcloud's OAuth session handling.
|
||||
Now uses the real callback server with state parameters for reliable
|
||||
concurrent token acquisition without race conditions.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
# Get auth_states dict from callback server
|
||||
auth_states, callback_url = oauth_callback_server
|
||||
|
||||
start_time = time.time()
|
||||
logger.info("Fetching OAuth tokens for all users in parallel...")
|
||||
logger.info(f"Using callback server at {callback_url} with state-based correlation")
|
||||
|
||||
async def get_token_with_delay(username: str, config: dict, delay: float):
|
||||
"""Get token for a user after a small delay to stagger requests."""
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
return await _get_oauth_token_for_user(
|
||||
browser, shared_oauth_client_credentials, username, config["password"]
|
||||
browser,
|
||||
shared_oauth_client_credentials,
|
||||
auth_states,
|
||||
username,
|
||||
config["password"],
|
||||
)
|
||||
|
||||
# Create tasks for all users with staggered starts (2.0s apart)
|
||||
tasks = {
|
||||
username: get_token_with_delay(username, config, (idx + 1) * 2.0)
|
||||
username: get_token_with_delay(username, config, idx * 0.5)
|
||||
for idx, (username, config) in enumerate(test_users_setup.items())
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user