feat: Implement ADR-004 Progressive Consent foundation components
- Token Broker Service manages Nextcloud access tokens with audience validation - Implements short-lived token caching (5-minute TTL) with early refresh - Enhanced token storage schema with ADR-004 fields (flow_type, audience, provisioning) - MCP provisioning tools for explicit Flow 2 resource authorization - Comprehensive unit tests for Token Broker Service (14 tests, all passing) - Environment configuration for Progressive Consent mode This implements the foundation for the dual OAuth flow architecture where: - Flow 1: MCP clients authenticate to MCP server (aud: "mcp-server") - Flow 2: MCP server gets delegated Nextcloud access (aud: "nextcloud") Users must explicitly call provision_nextcloud_access tool to grant resource access, implementing the "stateless by default" principle from ADR-004. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
+23
@@ -21,6 +21,29 @@ NEXTCLOUD_MCP_SERVER_URL=http://localhost:8000
|
||||
# TOKEN_STORAGE_DB: Path to SQLite database (default: /app/data/tokens.db)
|
||||
#TOKEN_STORAGE_DB=/app/data/tokens.db
|
||||
|
||||
# ===== ADR-004 PROGRESSIVE CONSENT CONFIGURATION =====
|
||||
# Enable Progressive Consent mode (dual OAuth flows)
|
||||
# When enabled: Flow 1 for client auth, Flow 2 for Nextcloud resource access
|
||||
# When disabled: Uses existing hybrid flow (backward compatible)
|
||||
#ENABLE_PROGRESSIVE_CONSENT=false
|
||||
|
||||
# MCP Server OAuth Client Configuration
|
||||
# The MCP server's own OAuth client credentials for Flow 2
|
||||
# If not set, will use dynamic client registration
|
||||
#MCP_SERVER_CLIENT_ID=
|
||||
#MCP_SERVER_CLIENT_SECRET=
|
||||
|
||||
# Allowed MCP Client IDs (comma-separated list)
|
||||
# Client IDs that are allowed to authenticate in Flow 1
|
||||
# Examples: claude-desktop,continue-dev,zed-editor
|
||||
#ALLOWED_MCP_CLIENTS=claude-desktop,continue-dev,zed-editor
|
||||
|
||||
# Token cache configuration for Token Broker Service
|
||||
# Cache TTL in seconds (default: 300 = 5 minutes)
|
||||
#TOKEN_CACHE_TTL=300
|
||||
# Early refresh threshold in seconds (default: 30)
|
||||
#TOKEN_CACHE_EARLY_REFRESH=30
|
||||
|
||||
# Option 2: Basic Authentication (LEGACY - Less Secure)
|
||||
# - Requires username and password
|
||||
# - Credentials stored in environment variables
|
||||
|
||||
@@ -98,7 +98,13 @@ class RefreshTokenStorage:
|
||||
encrypted_token BLOB NOT NULL,
|
||||
expires_at INTEGER,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
updated_at INTEGER NOT NULL,
|
||||
-- ADR-004 Progressive Consent fields
|
||||
flow_type TEXT DEFAULT 'hybrid', -- 'hybrid', 'flow1', 'flow2'
|
||||
token_audience TEXT DEFAULT 'nextcloud', -- 'mcp-server' or 'nextcloud'
|
||||
provisioned_at INTEGER, -- When Flow 2 was completed
|
||||
provisioning_client_id TEXT, -- Which MCP client initiated Flow 1
|
||||
scopes TEXT -- JSON array of granted scopes
|
||||
)
|
||||
"""
|
||||
)
|
||||
@@ -142,7 +148,7 @@ class RefreshTokenStorage:
|
||||
"""
|
||||
)
|
||||
|
||||
# OAuth flow sessions (ADR-004 Hybrid Flow)
|
||||
# OAuth flow sessions (ADR-004 Progressive Consent)
|
||||
await db.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS oauth_sessions (
|
||||
@@ -157,7 +163,12 @@ class RefreshTokenStorage:
|
||||
idp_refresh_token TEXT,
|
||||
user_id TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
expires_at INTEGER NOT NULL
|
||||
expires_at INTEGER NOT NULL,
|
||||
-- ADR-004 Progressive Consent fields
|
||||
flow_type TEXT DEFAULT 'hybrid', -- 'hybrid', 'flow1', 'flow2'
|
||||
requested_scopes TEXT, -- JSON array of requested scopes
|
||||
granted_scopes TEXT, -- JSON array of granted scopes
|
||||
is_provisioning BOOLEAN DEFAULT FALSE -- True if this is a Flow 2 provisioning session
|
||||
)
|
||||
"""
|
||||
)
|
||||
@@ -181,6 +192,10 @@ class RefreshTokenStorage:
|
||||
user_id: str,
|
||||
refresh_token: str,
|
||||
expires_at: Optional[int] = None,
|
||||
flow_type: str = "hybrid",
|
||||
token_audience: str = "nextcloud",
|
||||
provisioning_client_id: Optional[str] = None,
|
||||
scopes: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Store encrypted refresh token for user.
|
||||
@@ -189,6 +204,10 @@ class RefreshTokenStorage:
|
||||
user_id: User identifier (from OIDC 'sub' claim)
|
||||
refresh_token: Refresh token to store
|
||||
expires_at: Token expiration timestamp (Unix epoch), if known
|
||||
flow_type: Type of flow ('hybrid', 'flow1', 'flow2')
|
||||
token_audience: Token audience ('mcp-server' or 'nextcloud')
|
||||
provisioning_client_id: Client ID that initiated Flow 1
|
||||
scopes: List of granted scopes
|
||||
|
||||
"""
|
||||
if not self._initialized:
|
||||
@@ -196,15 +215,33 @@ class RefreshTokenStorage:
|
||||
|
||||
encrypted_token = self.cipher.encrypt(refresh_token.encode())
|
||||
now = int(time.time())
|
||||
scopes_json = json.dumps(scopes) if scopes else None
|
||||
|
||||
# For Flow 2, set provisioned_at timestamp
|
||||
provisioned_at = now if flow_type == "flow2" else None
|
||||
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO refresh_tokens
|
||||
(user_id, encrypted_token, expires_at, created_at, updated_at)
|
||||
VALUES (?, ?, ?, COALESCE((SELECT created_at FROM refresh_tokens WHERE user_id = ?), ?), ?)
|
||||
(user_id, encrypted_token, expires_at, created_at, updated_at,
|
||||
flow_type, token_audience, provisioned_at, provisioning_client_id, scopes)
|
||||
VALUES (?, ?, ?, COALESCE((SELECT created_at FROM refresh_tokens WHERE user_id = ?), ?), ?,
|
||||
?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, encrypted_token, expires_at, user_id, now, now),
|
||||
(
|
||||
user_id,
|
||||
encrypted_token,
|
||||
expires_at,
|
||||
user_id,
|
||||
now,
|
||||
now,
|
||||
flow_type,
|
||||
token_audience,
|
||||
provisioned_at,
|
||||
provisioning_client_id,
|
||||
scopes_json,
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
@@ -220,7 +257,7 @@ class RefreshTokenStorage:
|
||||
auth_method="offline_access",
|
||||
)
|
||||
|
||||
async def get_refresh_token(self, user_id: str) -> Optional[str]:
|
||||
async def get_refresh_token(self, user_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Retrieve and decrypt refresh token for user.
|
||||
|
||||
@@ -228,14 +265,28 @@ class RefreshTokenStorage:
|
||||
user_id: User identifier
|
||||
|
||||
Returns:
|
||||
Decrypted refresh token, or None if not found or expired
|
||||
Dictionary with token data including ADR-004 fields:
|
||||
{
|
||||
"refresh_token": str,
|
||||
"expires_at": int | None,
|
||||
"flow_type": str,
|
||||
"token_audience": str,
|
||||
"provisioned_at": int | None,
|
||||
"provisioning_client_id": str | None,
|
||||
"scopes": list[str] | None
|
||||
}
|
||||
or None if not found or expired
|
||||
"""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
"SELECT encrypted_token, expires_at FROM refresh_tokens WHERE user_id = ?",
|
||||
"""
|
||||
SELECT encrypted_token, expires_at, flow_type, token_audience,
|
||||
provisioned_at, provisioning_client_id, scopes
|
||||
FROM refresh_tokens WHERE user_id = ?
|
||||
""",
|
||||
(user_id,),
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
@@ -244,7 +295,15 @@ class RefreshTokenStorage:
|
||||
logger.debug(f"No refresh token found for user {user_id}")
|
||||
return None
|
||||
|
||||
encrypted_token, expires_at = row
|
||||
(
|
||||
encrypted_token,
|
||||
expires_at,
|
||||
flow_type,
|
||||
token_audience,
|
||||
provisioned_at,
|
||||
provisioning_client_id,
|
||||
scopes_json,
|
||||
) = row
|
||||
|
||||
# Check expiration
|
||||
if expires_at is not None and expires_at < time.time():
|
||||
@@ -256,8 +315,22 @@ class RefreshTokenStorage:
|
||||
|
||||
try:
|
||||
decrypted_token = self.cipher.decrypt(encrypted_token).decode()
|
||||
logger.debug(f"Retrieved refresh token for user {user_id}")
|
||||
return decrypted_token
|
||||
scopes = json.loads(scopes_json) if scopes_json else None
|
||||
|
||||
logger.debug(
|
||||
f"Retrieved refresh token for user {user_id} (flow_type: {flow_type})"
|
||||
)
|
||||
|
||||
return {
|
||||
"refresh_token": decrypted_token,
|
||||
"expires_at": expires_at,
|
||||
"flow_type": flow_type or "hybrid", # Default for existing tokens
|
||||
"token_audience": token_audience
|
||||
or "nextcloud", # Default for existing tokens
|
||||
"provisioned_at": provisioned_at,
|
||||
"provisioning_client_id": provisioning_client_id,
|
||||
"scopes": scopes,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt refresh token for user {user_id}: {e}")
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
Token Broker Service for ADR-004 Progressive Consent Architecture.
|
||||
|
||||
This service manages the lifecycle of Nextcloud access tokens, implementing
|
||||
the dual OAuth flow pattern where:
|
||||
1. MCP clients authenticate to MCP server with aud:"mcp-server" tokens
|
||||
2. MCP server uses stored refresh tokens to obtain aud:"nextcloud" tokens
|
||||
|
||||
The Token Broker provides:
|
||||
- Automatic token refresh when expired
|
||||
- Short-lived token caching (5-minute TTL)
|
||||
- Master refresh token rotation
|
||||
- Audience-specific token validation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from nextcloud_mcp_server.auth.refresh_token_storage import RefreshTokenStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenCache:
|
||||
"""In-memory cache for short-lived Nextcloud access tokens."""
|
||||
|
||||
def __init__(self, ttl_seconds: int = 300, early_refresh_seconds: int = 30):
|
||||
"""
|
||||
Initialize the token cache.
|
||||
|
||||
Args:
|
||||
ttl_seconds: Default TTL for cached tokens (5 minutes default)
|
||||
early_refresh_seconds: How many seconds before expiry to trigger early refresh (30s default)
|
||||
"""
|
||||
self._cache: Dict[str, Tuple[str, datetime]] = {}
|
||||
self._ttl = timedelta(seconds=ttl_seconds)
|
||||
self._early_refresh = timedelta(seconds=early_refresh_seconds)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get(self, user_id: str) -> Optional[str]:
|
||||
"""Get cached token if valid."""
|
||||
async with self._lock:
|
||||
if user_id not in self._cache:
|
||||
return None
|
||||
|
||||
token, expiry = self._cache[user_id]
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Check if token has expired
|
||||
if now >= expiry:
|
||||
del self._cache[user_id]
|
||||
logger.debug(f"Cached token expired for user {user_id}")
|
||||
return None
|
||||
|
||||
# Check if token will expire soon (refresh early)
|
||||
if now >= expiry - self._early_refresh:
|
||||
logger.debug(f"Cached token expiring soon for user {user_id}")
|
||||
return None
|
||||
|
||||
logger.debug(f"Using cached token for user {user_id}")
|
||||
return token
|
||||
|
||||
async def set(self, user_id: str, token: str, expires_in: int = None):
|
||||
"""Store token in cache."""
|
||||
async with self._lock:
|
||||
# Use provided expiry or default TTL
|
||||
if expires_in:
|
||||
expiry = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
|
||||
else:
|
||||
expiry = datetime.now(timezone.utc) + self._ttl
|
||||
|
||||
self._cache[user_id] = (token, expiry)
|
||||
logger.debug(f"Cached token for user {user_id} until {expiry}")
|
||||
|
||||
async def invalidate(self, user_id: str):
|
||||
"""Remove token from cache."""
|
||||
async with self._lock:
|
||||
if user_id in self._cache:
|
||||
del self._cache[user_id]
|
||||
logger.debug(f"Invalidated cached token for user {user_id}")
|
||||
|
||||
|
||||
class TokenBrokerService:
|
||||
"""
|
||||
Manages token lifecycle for the Progressive Consent architecture.
|
||||
|
||||
This service handles:
|
||||
- Getting or refreshing Nextcloud access tokens
|
||||
- Managing a short-lived token cache
|
||||
- Refreshing master refresh tokens periodically
|
||||
- Validating token audiences
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: RefreshTokenStorage,
|
||||
oidc_discovery_url: str,
|
||||
nextcloud_host: str,
|
||||
encryption_key: str,
|
||||
cache_ttl: int = 300,
|
||||
cache_early_refresh: int = 30,
|
||||
):
|
||||
"""
|
||||
Initialize the Token Broker Service.
|
||||
|
||||
Args:
|
||||
storage: Database storage for refresh tokens
|
||||
oidc_discovery_url: OIDC provider discovery URL
|
||||
nextcloud_host: Nextcloud server URL
|
||||
encryption_key: Fernet key for token encryption
|
||||
cache_ttl: Cache TTL in seconds (default: 5 minutes)
|
||||
cache_early_refresh: Early refresh threshold in seconds (default: 30 seconds)
|
||||
"""
|
||||
self.storage = storage
|
||||
self.oidc_discovery_url = oidc_discovery_url
|
||||
self.nextcloud_host = nextcloud_host
|
||||
self.fernet = Fernet(
|
||||
encryption_key.encode()
|
||||
if isinstance(encryption_key, str)
|
||||
else encryption_key
|
||||
)
|
||||
self.cache = TokenCache(cache_ttl, cache_early_refresh)
|
||||
self._oidc_config = None
|
||||
self._http_client = None
|
||||
|
||||
async def _get_http_client(self) -> httpx.AsyncClient:
|
||||
"""Get or create HTTP client."""
|
||||
if self._http_client is None:
|
||||
self._http_client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(30.0), follow_redirects=True
|
||||
)
|
||||
return self._http_client
|
||||
|
||||
async def _get_oidc_config(self) -> dict:
|
||||
"""Get OIDC configuration from discovery endpoint."""
|
||||
if self._oidc_config is None:
|
||||
client = await self._get_http_client()
|
||||
response = await client.get(self.oidc_discovery_url)
|
||||
response.raise_for_status()
|
||||
self._oidc_config = response.json()
|
||||
return self._oidc_config
|
||||
|
||||
async def get_nextcloud_token(self, user_id: str) -> Optional[str]:
|
||||
"""
|
||||
Get a valid Nextcloud access token for the user.
|
||||
|
||||
This method:
|
||||
1. Checks the cache for a valid token
|
||||
2. If not cached, checks for stored refresh token
|
||||
3. If refresh token exists, obtains new access token
|
||||
4. Caches the new token for future requests
|
||||
|
||||
Args:
|
||||
user_id: The user identifier
|
||||
|
||||
Returns:
|
||||
Valid Nextcloud access token or None if not provisioned
|
||||
"""
|
||||
# Check cache first
|
||||
cached_token = await self.cache.get(user_id)
|
||||
if cached_token:
|
||||
return cached_token
|
||||
|
||||
# Get stored refresh token
|
||||
refresh_data = await self.storage.get_refresh_token(user_id)
|
||||
if not refresh_data:
|
||||
logger.info(f"No refresh token found for user {user_id}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Decrypt refresh token
|
||||
encrypted_token = refresh_data["refresh_token"]
|
||||
refresh_token = self.fernet.decrypt(encrypted_token.encode()).decode()
|
||||
|
||||
# Exchange refresh token for new access token
|
||||
access_token, expires_in = await self._refresh_access_token(refresh_token)
|
||||
|
||||
# Cache the new token
|
||||
await self.cache.set(user_id, access_token, expires_in)
|
||||
|
||||
return access_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get Nextcloud token for user {user_id}: {e}")
|
||||
# Invalidate cache on error
|
||||
await self.cache.invalidate(user_id)
|
||||
return None
|
||||
|
||||
async def _refresh_access_token(self, refresh_token: str) -> Tuple[str, int]:
|
||||
"""
|
||||
Exchange refresh token for new access token.
|
||||
|
||||
Args:
|
||||
refresh_token: The refresh token
|
||||
|
||||
Returns:
|
||||
Tuple of (access_token, expires_in_seconds)
|
||||
"""
|
||||
config = await self._get_oidc_config()
|
||||
token_endpoint = config["token_endpoint"]
|
||||
|
||||
client = await self._get_http_client()
|
||||
|
||||
# Request new access token using refresh token
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"scope": "openid profile email notes:read notes:write calendar:read calendar:write",
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
token_endpoint,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"Token refresh failed: {response.status_code} - {response.text}"
|
||||
)
|
||||
raise Exception(f"Token refresh failed: {response.status_code}")
|
||||
|
||||
token_data = response.json()
|
||||
access_token = token_data["access_token"]
|
||||
expires_in = token_data.get("expires_in", 3600) # Default 1 hour
|
||||
|
||||
# Validate audience
|
||||
await self._validate_token_audience(access_token, "nextcloud")
|
||||
|
||||
logger.info(f"Refreshed access token (expires in {expires_in}s)")
|
||||
return access_token, expires_in
|
||||
|
||||
async def _validate_token_audience(self, token: str, expected_audience: str):
|
||||
"""
|
||||
Validate that token has correct audience claim.
|
||||
|
||||
Args:
|
||||
token: JWT token to validate
|
||||
expected_audience: Expected audience value
|
||||
|
||||
Raises:
|
||||
ValueError: If audience doesn't match
|
||||
"""
|
||||
try:
|
||||
# Decode without verification to check claims
|
||||
# In production, should verify signature
|
||||
claims = jwt.decode(token, options={"verify_signature": False})
|
||||
|
||||
audience = claims.get("aud", [])
|
||||
if isinstance(audience, str):
|
||||
audience = [audience]
|
||||
|
||||
if expected_audience not in audience:
|
||||
raise ValueError(
|
||||
f"Token audience {audience} doesn't include {expected_audience}"
|
||||
)
|
||||
|
||||
except jwt.DecodeError as e:
|
||||
# Token might be opaque, skip validation
|
||||
logger.debug(f"Cannot decode token for audience validation: {e}")
|
||||
|
||||
async def refresh_master_token(self, user_id: str) -> bool:
|
||||
"""
|
||||
Refresh the master refresh token (periodic rotation).
|
||||
|
||||
This should be called periodically (e.g., daily) to rotate
|
||||
refresh tokens for security.
|
||||
|
||||
Args:
|
||||
user_id: The user identifier
|
||||
|
||||
Returns:
|
||||
True if refresh successful, False otherwise
|
||||
"""
|
||||
refresh_data = await self.storage.get_refresh_token(user_id)
|
||||
if not refresh_data:
|
||||
logger.warning(f"No refresh token to rotate for user {user_id}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Decrypt current refresh token
|
||||
encrypted_token = refresh_data["refresh_token"]
|
||||
current_refresh_token = self.fernet.decrypt(
|
||||
encrypted_token.encode()
|
||||
).decode()
|
||||
|
||||
# Get OIDC configuration
|
||||
config = await self._get_oidc_config()
|
||||
token_endpoint = config["token_endpoint"]
|
||||
|
||||
client = await self._get_http_client()
|
||||
|
||||
# Request new refresh token
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": current_refresh_token,
|
||||
"scope": "openid profile email offline_access notes:read notes:write calendar:read calendar:write",
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
token_endpoint,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Master token refresh failed: {response.status_code}")
|
||||
return False
|
||||
|
||||
token_data = response.json()
|
||||
new_refresh_token = token_data.get("refresh_token")
|
||||
|
||||
if new_refresh_token and new_refresh_token != current_refresh_token:
|
||||
# Encrypt and store new refresh token
|
||||
encrypted_new = self.fernet.encrypt(new_refresh_token.encode()).decode()
|
||||
await self.storage.store_refresh_token(
|
||||
user_id=user_id,
|
||||
refresh_token=encrypted_new,
|
||||
expires_at=datetime.now(timezone.utc)
|
||||
+ timedelta(days=90), # 90-day expiry
|
||||
)
|
||||
logger.info(f"Rotated master refresh token for user {user_id}")
|
||||
|
||||
# Invalidate cached access token
|
||||
await self.cache.invalidate(user_id)
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh master token for user {user_id}: {e}")
|
||||
return False
|
||||
|
||||
async def has_nextcloud_provisioning(self, user_id: str) -> bool:
|
||||
"""
|
||||
Check if user has provisioned Nextcloud access (Flow 2).
|
||||
|
||||
Args:
|
||||
user_id: The user identifier
|
||||
|
||||
Returns:
|
||||
True if user has stored refresh token, False otherwise
|
||||
"""
|
||||
refresh_data = await self.storage.get_refresh_token(user_id)
|
||||
return refresh_data is not None
|
||||
|
||||
async def revoke_nextcloud_access(self, user_id: str) -> bool:
|
||||
"""
|
||||
Revoke stored Nextcloud access for a user.
|
||||
|
||||
This removes stored refresh tokens and clears cache.
|
||||
|
||||
Args:
|
||||
user_id: The user identifier
|
||||
|
||||
Returns:
|
||||
True if revocation successful
|
||||
"""
|
||||
try:
|
||||
# Get refresh token for revocation at IdP
|
||||
refresh_data = await self.storage.get_refresh_token(user_id)
|
||||
if refresh_data:
|
||||
try:
|
||||
# Attempt to revoke at IdP
|
||||
encrypted_token = refresh_data["refresh_token"]
|
||||
refresh_token = self.fernet.decrypt(
|
||||
encrypted_token.encode()
|
||||
).decode()
|
||||
await self._revoke_token_at_idp(refresh_token)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to revoke at IdP: {e}")
|
||||
|
||||
# Remove from storage
|
||||
await self.storage.delete_refresh_token(user_id)
|
||||
|
||||
# Clear cache
|
||||
await self.cache.invalidate(user_id)
|
||||
|
||||
logger.info(f"Revoked Nextcloud access for user {user_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke access for user {user_id}: {e}")
|
||||
return False
|
||||
|
||||
async def _revoke_token_at_idp(self, token: str):
|
||||
"""Revoke token at the IdP if revocation endpoint exists."""
|
||||
config = await self._get_oidc_config()
|
||||
revocation_endpoint = config.get("revocation_endpoint")
|
||||
|
||||
if not revocation_endpoint:
|
||||
logger.debug("No revocation endpoint available")
|
||||
return
|
||||
|
||||
client = await self._get_http_client()
|
||||
|
||||
data = {"token": token, "token_type_hint": "refresh_token"}
|
||||
|
||||
response = await client.post(
|
||||
revocation_endpoint,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info("Token revoked at IdP")
|
||||
else:
|
||||
logger.warning(f"Token revocation returned {response.status_code}")
|
||||
|
||||
async def close(self):
|
||||
"""Clean up resources."""
|
||||
if self._http_client:
|
||||
await self._http_client.aclose()
|
||||
@@ -0,0 +1,400 @@
|
||||
"""
|
||||
MCP Tools for OAuth and Provisioning Management (ADR-004 Progressive Consent).
|
||||
|
||||
This module provides MCP tools that enable users to explicitly provision
|
||||
Nextcloud access using the Flow 2 (Resource Provisioning) OAuth flow.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from mcp.server.fastmcp import Context
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from nextcloud_mcp_server.auth.refresh_token_storage import RefreshTokenStorage
|
||||
from nextcloud_mcp_server.auth.token_broker import TokenBrokerService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProvisioningStatus(BaseModel):
|
||||
"""Status of Nextcloud provisioning for a user."""
|
||||
|
||||
is_provisioned: bool = Field(description="Whether Nextcloud access is provisioned")
|
||||
provisioned_at: Optional[str] = Field(
|
||||
None, description="ISO timestamp when provisioned"
|
||||
)
|
||||
client_id: Optional[str] = Field(
|
||||
None, description="Client ID that initiated the original Flow 1"
|
||||
)
|
||||
scopes: Optional[list[str]] = Field(None, description="Granted scopes")
|
||||
flow_type: Optional[str] = Field(
|
||||
None, description="Type of flow used ('hybrid', 'flow1', 'flow2')"
|
||||
)
|
||||
|
||||
|
||||
class ProvisioningResult(BaseModel):
|
||||
"""Result of provisioning attempt."""
|
||||
|
||||
success: bool = Field(description="Whether provisioning was initiated")
|
||||
authorization_url: Optional[str] = Field(
|
||||
None, description="URL for user to complete OAuth authorization"
|
||||
)
|
||||
message: str = Field(description="Status message for the user")
|
||||
already_provisioned: bool = Field(
|
||||
False, description="Whether access was already provisioned"
|
||||
)
|
||||
|
||||
|
||||
class RevocationResult(BaseModel):
|
||||
"""Result of access revocation."""
|
||||
|
||||
success: bool = Field(description="Whether revocation succeeded")
|
||||
message: str = Field(description="Status message for the user")
|
||||
|
||||
|
||||
async def get_provisioning_status(mcp: Context, user_id: str) -> ProvisioningStatus:
|
||||
"""
|
||||
Check the provisioning status for Nextcloud access.
|
||||
|
||||
This checks whether the user has completed Flow 2 to provision
|
||||
offline access to Nextcloud resources.
|
||||
|
||||
Args:
|
||||
mcp: MCP context
|
||||
user_id: User identifier
|
||||
|
||||
Returns:
|
||||
ProvisioningStatus with current provisioning state
|
||||
"""
|
||||
storage = RefreshTokenStorage.from_env()
|
||||
await storage.initialize()
|
||||
|
||||
token_data = await storage.get_refresh_token(user_id)
|
||||
|
||||
if not token_data:
|
||||
return ProvisioningStatus(is_provisioned=False)
|
||||
|
||||
# Convert timestamp to ISO format if present
|
||||
provisioned_at_str = None
|
||||
if token_data.get("provisioned_at"):
|
||||
from datetime import datetime, timezone
|
||||
|
||||
dt = datetime.fromtimestamp(token_data["provisioned_at"], tz=timezone.utc)
|
||||
provisioned_at_str = dt.isoformat()
|
||||
|
||||
return ProvisioningStatus(
|
||||
is_provisioned=True,
|
||||
provisioned_at=provisioned_at_str,
|
||||
client_id=token_data.get("provisioning_client_id"),
|
||||
scopes=token_data.get("scopes"),
|
||||
flow_type=token_data.get("flow_type", "hybrid"),
|
||||
)
|
||||
|
||||
|
||||
def generate_oauth_url_for_flow2(
|
||||
oidc_discovery_url: str,
|
||||
server_client_id: str,
|
||||
redirect_uri: str,
|
||||
state: str,
|
||||
scopes: list[str],
|
||||
) -> str:
|
||||
"""
|
||||
Generate OAuth authorization URL for Flow 2 (Resource Provisioning).
|
||||
|
||||
This creates the URL that the MCP server uses to get delegated
|
||||
access to Nextcloud on behalf of the user.
|
||||
|
||||
Args:
|
||||
oidc_discovery_url: OIDC provider discovery URL
|
||||
server_client_id: MCP server's OAuth client ID
|
||||
redirect_uri: Callback URL for the MCP server
|
||||
state: CSRF protection state
|
||||
scopes: List of scopes to request
|
||||
|
||||
Returns:
|
||||
Complete authorization URL for Flow 2
|
||||
"""
|
||||
# Extract base URL from discovery URL
|
||||
# Format: https://example.com/.well-known/openid-configuration
|
||||
# We need: https://example.com/apps/oidc/authorize
|
||||
base_url = oidc_discovery_url.replace("/.well-known/openid-configuration", "")
|
||||
auth_endpoint = f"{base_url}/apps/oidc/authorize"
|
||||
|
||||
# Build OAuth parameters
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": server_client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": " ".join(scopes),
|
||||
"state": state,
|
||||
# Request offline access for background operations
|
||||
"access_type": "offline",
|
||||
"prompt": "consent", # Force consent screen to show scopes
|
||||
}
|
||||
|
||||
return f"{auth_endpoint}?{urlencode(params)}"
|
||||
|
||||
|
||||
async def provision_nextcloud_access(
|
||||
mcp: Context, user_id: Optional[str] = None
|
||||
) -> ProvisioningResult:
|
||||
"""
|
||||
MCP Tool: Provision offline access to Nextcloud resources.
|
||||
|
||||
This tool initiates Flow 2 of the Progressive Consent architecture,
|
||||
allowing the MCP server to obtain delegated access to Nextcloud APIs.
|
||||
|
||||
The user must complete the OAuth flow in their browser to grant access.
|
||||
|
||||
Args:
|
||||
mcp: MCP context
|
||||
user_id: Optional user identifier (extracted from token if not provided)
|
||||
|
||||
Returns:
|
||||
ProvisioningResult with authorization URL or status
|
||||
"""
|
||||
try:
|
||||
# Get user ID from context if not provided
|
||||
if not user_id:
|
||||
# In a real implementation, extract from the MCP access token
|
||||
user_id = mcp.context.get("user_id", "default_user")
|
||||
|
||||
# Check if already provisioned
|
||||
status = await get_provisioning_status(mcp, user_id)
|
||||
if status.is_provisioned:
|
||||
return ProvisioningResult(
|
||||
success=True,
|
||||
already_provisioned=True,
|
||||
message=(
|
||||
f"Nextcloud access is already provisioned (since {status.provisioned_at}). "
|
||||
"Use 'revoke_nextcloud_access' if you want to re-provision."
|
||||
),
|
||||
)
|
||||
|
||||
# Get configuration
|
||||
enable_progressive = (
|
||||
os.getenv("ENABLE_PROGRESSIVE_CONSENT", "false").lower() == "true"
|
||||
)
|
||||
if not enable_progressive:
|
||||
return ProvisioningResult(
|
||||
success=False,
|
||||
message=(
|
||||
"Progressive Consent is not enabled. "
|
||||
"Set ENABLE_PROGRESSIVE_CONSENT=true to use this feature."
|
||||
),
|
||||
)
|
||||
|
||||
# Get MCP server's OAuth client credentials
|
||||
server_client_id = os.getenv("MCP_SERVER_CLIENT_ID")
|
||||
if not server_client_id:
|
||||
# In production, would use Dynamic Client Registration here
|
||||
return ProvisioningResult(
|
||||
success=False,
|
||||
message=(
|
||||
"MCP server OAuth client not configured. "
|
||||
"Administrator must set MCP_SERVER_CLIENT_ID."
|
||||
),
|
||||
)
|
||||
|
||||
# Generate OAuth URL for Flow 2
|
||||
oidc_discovery_url = os.getenv(
|
||||
"OIDC_DISCOVERY_URL",
|
||||
f"{os.getenv('NEXTCLOUD_HOST')}/.well-known/openid-configuration",
|
||||
)
|
||||
|
||||
# Generate secure state for CSRF protection
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
# Store state in session for validation on callback
|
||||
storage = RefreshTokenStorage.from_env()
|
||||
await storage.initialize()
|
||||
|
||||
# Create OAuth session for Flow 2
|
||||
session_id = f"flow2_{user_id}_{secrets.token_hex(8)}"
|
||||
redirect_uri = f"{os.getenv('NEXTCLOUD_MCP_SERVER_URL', 'http://localhost:8000')}/oauth/callback-nextcloud"
|
||||
|
||||
await storage.store_oauth_session(
|
||||
session_id=session_id,
|
||||
client_redirect_uri="", # No client redirect for Flow 2
|
||||
state=state,
|
||||
flow_type="flow2",
|
||||
is_provisioning=True,
|
||||
ttl_seconds=600, # 10 minute TTL
|
||||
)
|
||||
|
||||
# Define scopes for Nextcloud access
|
||||
scopes = [
|
||||
"openid",
|
||||
"profile",
|
||||
"email",
|
||||
"offline_access", # Critical for background operations
|
||||
"notes:read",
|
||||
"notes:write",
|
||||
"calendar:read",
|
||||
"calendar:write",
|
||||
"contacts:read",
|
||||
"contacts:write",
|
||||
"files:read",
|
||||
"files:write",
|
||||
]
|
||||
|
||||
# Generate authorization URL
|
||||
auth_url = generate_oauth_url_for_flow2(
|
||||
oidc_discovery_url=oidc_discovery_url,
|
||||
server_client_id=server_client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
state=state,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
return ProvisioningResult(
|
||||
success=True,
|
||||
authorization_url=auth_url,
|
||||
message=(
|
||||
"Please visit the authorization URL to grant the MCP server "
|
||||
"offline access to your Nextcloud resources. This is a one-time "
|
||||
"setup that allows the server to access Nextcloud on your behalf "
|
||||
"even when you're not actively connected."
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate provisioning: {e}")
|
||||
return ProvisioningResult(
|
||||
success=False,
|
||||
message=f"Failed to initiate provisioning: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
async def revoke_nextcloud_access(
|
||||
mcp: Context, user_id: Optional[str] = None
|
||||
) -> RevocationResult:
|
||||
"""
|
||||
MCP Tool: Revoke offline access to Nextcloud resources.
|
||||
|
||||
This tool removes the stored refresh token and revokes access
|
||||
that was granted via Flow 2.
|
||||
|
||||
Args:
|
||||
mcp: MCP context
|
||||
user_id: Optional user identifier
|
||||
|
||||
Returns:
|
||||
RevocationResult with status
|
||||
"""
|
||||
try:
|
||||
# Get user ID from context if not provided
|
||||
if not user_id:
|
||||
user_id = mcp.context.get("user_id", "default_user")
|
||||
|
||||
# Check current status
|
||||
status = await get_provisioning_status(mcp, user_id)
|
||||
if not status.is_provisioned:
|
||||
return RevocationResult(
|
||||
success=True,
|
||||
message="No Nextcloud access to revoke.",
|
||||
)
|
||||
|
||||
# Initialize Token Broker to handle revocation
|
||||
storage = RefreshTokenStorage.from_env()
|
||||
await storage.initialize()
|
||||
|
||||
encryption_key = os.getenv("TOKEN_ENCRYPTION_KEY")
|
||||
if not encryption_key:
|
||||
return RevocationResult(
|
||||
success=False,
|
||||
message="Token encryption key not configured.",
|
||||
)
|
||||
|
||||
broker = TokenBrokerService(
|
||||
storage=storage,
|
||||
oidc_discovery_url=os.getenv(
|
||||
"OIDC_DISCOVERY_URL",
|
||||
f"{os.getenv('NEXTCLOUD_HOST')}/.well-known/openid-configuration",
|
||||
),
|
||||
nextcloud_host=os.getenv("NEXTCLOUD_HOST"),
|
||||
encryption_key=encryption_key,
|
||||
)
|
||||
|
||||
# Revoke access
|
||||
success = await broker.revoke_nextcloud_access(user_id)
|
||||
|
||||
if success:
|
||||
return RevocationResult(
|
||||
success=True,
|
||||
message=(
|
||||
"Successfully revoked Nextcloud access. "
|
||||
"You can run 'provision_nextcloud_access' again if needed."
|
||||
),
|
||||
)
|
||||
else:
|
||||
return RevocationResult(
|
||||
success=False,
|
||||
message="Failed to revoke access. Please try again.",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke access: {e}")
|
||||
return RevocationResult(
|
||||
success=False,
|
||||
message=f"Failed to revoke access: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
async def check_provisioning_status(
|
||||
mcp: Context, user_id: Optional[str] = None
|
||||
) -> ProvisioningStatus:
|
||||
"""
|
||||
MCP Tool: Check the current provisioning status.
|
||||
|
||||
This tool allows users to check whether they have provisioned
|
||||
Nextcloud access and see details about their current authorization.
|
||||
|
||||
Args:
|
||||
mcp: MCP context
|
||||
user_id: Optional user identifier
|
||||
|
||||
Returns:
|
||||
ProvisioningStatus with current state
|
||||
"""
|
||||
# Get user ID from context if not provided
|
||||
if not user_id:
|
||||
user_id = mcp.context.get("user_id", "default_user")
|
||||
|
||||
return await get_provisioning_status(mcp, user_id)
|
||||
|
||||
|
||||
# Register MCP tools
|
||||
def register_oauth_tools(mcp):
|
||||
"""Register OAuth and provisioning tools with the MCP server."""
|
||||
|
||||
@mcp.tool(
|
||||
name="provision_nextcloud_access",
|
||||
description=(
|
||||
"Provision offline access to Nextcloud resources. "
|
||||
"This is required before using Nextcloud tools. "
|
||||
"You'll need to complete an OAuth authorization in your browser."
|
||||
),
|
||||
)
|
||||
async def tool_provision_access(
|
||||
user_id: Optional[str] = None,
|
||||
) -> ProvisioningResult:
|
||||
return await provision_nextcloud_access(mcp, user_id)
|
||||
|
||||
@mcp.tool(
|
||||
name="revoke_nextcloud_access",
|
||||
description="Revoke offline access to Nextcloud resources.",
|
||||
)
|
||||
async def tool_revoke_access(user_id: Optional[str] = None) -> RevocationResult:
|
||||
return await revoke_nextcloud_access(mcp, user_id)
|
||||
|
||||
@mcp.tool(
|
||||
name="check_provisioning_status",
|
||||
description="Check whether Nextcloud access is provisioned.",
|
||||
)
|
||||
async def tool_check_status(user_id: Optional[str] = None) -> ProvisioningStatus:
|
||||
return await check_provisioning_status(mcp, user_id)
|
||||
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
Unit tests for Token Broker Service (ADR-004 Progressive Consent).
|
||||
|
||||
Tests the token management, caching, and refresh logic without
|
||||
requiring real network calls or database connections.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
import pytest
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from nextcloud_mcp_server.auth.token_broker import TokenBrokerService, TokenCache
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def encryption_key():
|
||||
"""Generate test encryption key."""
|
||||
return Fernet.generate_key().decode()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage():
|
||||
"""Mock RefreshTokenStorage."""
|
||||
storage = AsyncMock()
|
||||
storage.get_refresh_token = AsyncMock(return_value=None)
|
||||
storage.store_refresh_token = AsyncMock()
|
||||
storage.delete_refresh_token = AsyncMock()
|
||||
return storage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oidc_config():
|
||||
"""Mock OIDC configuration."""
|
||||
return {
|
||||
"issuer": "https://idp.example.com",
|
||||
"token_endpoint": "https://idp.example.com/token",
|
||||
"revocation_endpoint": "https://idp.example.com/revoke",
|
||||
"jwks_uri": "https://idp.example.com/jwks",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def token_broker(mock_storage, encryption_key):
|
||||
"""Create TokenBrokerService instance."""
|
||||
broker = TokenBrokerService(
|
||||
storage=mock_storage,
|
||||
oidc_discovery_url="https://idp.example.com/.well-known/openid-configuration",
|
||||
nextcloud_host="https://nextcloud.example.com",
|
||||
encryption_key=encryption_key,
|
||||
cache_ttl=300,
|
||||
)
|
||||
yield broker
|
||||
await broker.close()
|
||||
|
||||
|
||||
class TestTokenCache:
|
||||
"""Test the TokenCache component."""
|
||||
|
||||
async def test_cache_stores_and_retrieves_token(self):
|
||||
"""Test basic cache storage and retrieval."""
|
||||
cache = TokenCache(ttl_seconds=60)
|
||||
|
||||
# Store token with sufficient expiry time (more than 30s threshold)
|
||||
await cache.set("user1", "test_token", expires_in=120)
|
||||
|
||||
# Retrieve token
|
||||
token = await cache.get("user1")
|
||||
assert token == "test_token"
|
||||
|
||||
async def test_cache_respects_ttl(self):
|
||||
"""Test that cache respects TTL."""
|
||||
# Create cache with 1 second TTL and 0 second early refresh
|
||||
cache = TokenCache(ttl_seconds=1, early_refresh_seconds=0)
|
||||
|
||||
# Store token
|
||||
await cache.set("user1", "test_token")
|
||||
|
||||
# Token should be available immediately
|
||||
assert await cache.get("user1") == "test_token"
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Token should be expired
|
||||
assert await cache.get("user1") is None
|
||||
|
||||
async def test_cache_early_refresh(self):
|
||||
"""Test that cache returns None for tokens expiring soon."""
|
||||
cache = TokenCache(ttl_seconds=60)
|
||||
|
||||
# Store token that expires in 25 seconds (less than 30s threshold)
|
||||
await cache.set("user1", "test_token", expires_in=25)
|
||||
|
||||
# Should return None as it's expiring soon (within 30s)
|
||||
assert await cache.get("user1") is None
|
||||
|
||||
async def test_cache_invalidation(self):
|
||||
"""Test cache invalidation."""
|
||||
cache = TokenCache(ttl_seconds=60)
|
||||
|
||||
# Store and verify token
|
||||
await cache.set("user1", "test_token")
|
||||
assert await cache.get("user1") == "test_token"
|
||||
|
||||
# Invalidate
|
||||
await cache.invalidate("user1")
|
||||
|
||||
# Should be removed
|
||||
assert await cache.get("user1") is None
|
||||
|
||||
|
||||
class TestTokenBrokerService:
|
||||
"""Test the TokenBrokerService."""
|
||||
|
||||
async def test_has_nextcloud_provisioning(self, token_broker, mock_storage):
|
||||
"""Test checking if user has provisioned Nextcloud access."""
|
||||
# No provisioning
|
||||
mock_storage.get_refresh_token.return_value = None
|
||||
assert await token_broker.has_nextcloud_provisioning("user1") is False
|
||||
|
||||
# Has provisioning
|
||||
mock_storage.get_refresh_token.return_value = {
|
||||
"refresh_token": "encrypted_token",
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(days=30),
|
||||
}
|
||||
assert await token_broker.has_nextcloud_provisioning("user1") is True
|
||||
|
||||
async def test_get_nextcloud_token_from_cache(self, token_broker):
|
||||
"""Test getting token from cache."""
|
||||
# Pre-populate cache
|
||||
await token_broker.cache.set("user1", "cached_token", expires_in=300)
|
||||
|
||||
# Should return cached token without calling storage
|
||||
token = await token_broker.get_nextcloud_token("user1")
|
||||
assert token == "cached_token"
|
||||
token_broker.storage.get_refresh_token.assert_not_called()
|
||||
|
||||
async def test_get_nextcloud_token_refresh(
|
||||
self, token_broker, mock_storage, encryption_key, mock_oidc_config
|
||||
):
|
||||
"""Test getting token via refresh when not cached."""
|
||||
# Setup encrypted refresh token in storage
|
||||
fernet = Fernet(encryption_key.encode())
|
||||
encrypted_token = fernet.encrypt(b"test_refresh_token").decode()
|
||||
mock_storage.get_refresh_token.return_value = {
|
||||
"refresh_token": encrypted_token,
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(days=30),
|
||||
}
|
||||
|
||||
# Mock HTTP client for token refresh
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "new_access_token",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer",
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
token_broker, "_get_oidc_config", return_value=mock_oidc_config
|
||||
):
|
||||
with patch.object(token_broker, "_get_http_client") as mock_client:
|
||||
mock_client.return_value.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Get token (should refresh)
|
||||
token = await token_broker.get_nextcloud_token("user1")
|
||||
|
||||
assert token == "new_access_token"
|
||||
# Verify token was cached
|
||||
cached = await token_broker.cache.get("user1")
|
||||
assert cached == "new_access_token"
|
||||
|
||||
async def test_get_nextcloud_token_no_provisioning(
|
||||
self, token_broker, mock_storage
|
||||
):
|
||||
"""Test getting token when user hasn't provisioned."""
|
||||
mock_storage.get_refresh_token.return_value = None
|
||||
|
||||
token = await token_broker.get_nextcloud_token("user1")
|
||||
assert token is None
|
||||
|
||||
async def test_refresh_master_token(
|
||||
self, token_broker, mock_storage, encryption_key, mock_oidc_config
|
||||
):
|
||||
"""Test master refresh token rotation."""
|
||||
# Setup current refresh token
|
||||
fernet = Fernet(encryption_key.encode())
|
||||
encrypted_token = fernet.encrypt(b"current_refresh_token").decode()
|
||||
mock_storage.get_refresh_token.return_value = {
|
||||
"refresh_token": encrypted_token,
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(days=30),
|
||||
}
|
||||
|
||||
# Mock successful refresh response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "new_access",
|
||||
"refresh_token": "new_refresh_token",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
token_broker, "_get_oidc_config", return_value=mock_oidc_config
|
||||
):
|
||||
with patch.object(token_broker, "_get_http_client") as mock_client:
|
||||
mock_client.return_value.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Rotate token
|
||||
success = await token_broker.refresh_master_token("user1")
|
||||
|
||||
assert success is True
|
||||
# Verify new token was stored
|
||||
mock_storage.store_refresh_token.assert_called_once()
|
||||
call_args = mock_storage.store_refresh_token.call_args[1]
|
||||
assert call_args["user_id"] == "user1"
|
||||
# Decrypt to verify it's the new token
|
||||
stored_token = fernet.decrypt(
|
||||
call_args["refresh_token"].encode()
|
||||
).decode()
|
||||
assert stored_token == "new_refresh_token"
|
||||
|
||||
async def test_refresh_master_token_no_rotation(
|
||||
self, token_broker, mock_storage, encryption_key, mock_oidc_config
|
||||
):
|
||||
"""Test when IdP returns same refresh token (no rotation)."""
|
||||
# Setup current refresh token
|
||||
fernet = Fernet(encryption_key.encode())
|
||||
encrypted_token = fernet.encrypt(b"same_refresh_token").decode()
|
||||
mock_storage.get_refresh_token.return_value = {
|
||||
"refresh_token": encrypted_token,
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(days=30),
|
||||
}
|
||||
|
||||
# Mock response with same refresh token
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "new_access",
|
||||
"refresh_token": "same_refresh_token",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
token_broker, "_get_oidc_config", return_value=mock_oidc_config
|
||||
):
|
||||
with patch.object(token_broker, "_get_http_client") as mock_client:
|
||||
mock_client.return_value.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
success = await token_broker.refresh_master_token("user1")
|
||||
|
||||
assert success is True
|
||||
# Should not store if token didn't change
|
||||
mock_storage.store_refresh_token.assert_not_called()
|
||||
|
||||
async def test_revoke_nextcloud_access(
|
||||
self, token_broker, mock_storage, encryption_key, mock_oidc_config
|
||||
):
|
||||
"""Test revoking Nextcloud access."""
|
||||
# Setup refresh token for revocation
|
||||
fernet = Fernet(encryption_key.encode())
|
||||
encrypted_token = fernet.encrypt(b"token_to_revoke").decode()
|
||||
mock_storage.get_refresh_token.return_value = {
|
||||
"refresh_token": encrypted_token,
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(days=30),
|
||||
}
|
||||
|
||||
# Mock revocation response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch.object(
|
||||
token_broker, "_get_oidc_config", return_value=mock_oidc_config
|
||||
):
|
||||
with patch.object(token_broker, "_get_http_client") as mock_client:
|
||||
mock_client.return_value.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Pre-populate cache
|
||||
await token_broker.cache.set("user1", "cached_token")
|
||||
|
||||
# Revoke access
|
||||
success = await token_broker.revoke_nextcloud_access("user1")
|
||||
|
||||
assert success is True
|
||||
# Verify token was deleted from storage
|
||||
mock_storage.delete_refresh_token.assert_called_once_with("user1")
|
||||
# Verify cache was cleared
|
||||
assert await token_broker.cache.get("user1") is None
|
||||
|
||||
async def test_validate_token_audience(self, token_broker):
|
||||
"""Test token audience validation."""
|
||||
# Create test token with audience
|
||||
test_payload = {
|
||||
"sub": "user1",
|
||||
"aud": ["nextcloud", "other-service"],
|
||||
"exp": datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
}
|
||||
test_token = jwt.encode(test_payload, "secret", algorithm="HS256")
|
||||
|
||||
# Should not raise for correct audience
|
||||
await token_broker._validate_token_audience(test_token, "nextcloud")
|
||||
|
||||
# Should raise for wrong audience
|
||||
with pytest.raises(ValueError, match="doesn't include wrong-audience"):
|
||||
await token_broker._validate_token_audience(test_token, "wrong-audience")
|
||||
|
||||
async def test_token_refresh_with_network_error(
|
||||
self, token_broker, mock_storage, encryption_key
|
||||
):
|
||||
"""Test handling network errors during token refresh."""
|
||||
# Setup encrypted refresh token
|
||||
fernet = Fernet(encryption_key.encode())
|
||||
encrypted_token = fernet.encrypt(b"test_refresh_token").decode()
|
||||
mock_storage.get_refresh_token.return_value = {
|
||||
"refresh_token": encrypted_token,
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(days=30),
|
||||
}
|
||||
|
||||
# Mock network error
|
||||
with patch.object(token_broker, "_get_http_client") as mock_client:
|
||||
mock_client.return_value.post = AsyncMock(
|
||||
side_effect=httpx.NetworkError("Connection failed")
|
||||
)
|
||||
|
||||
# Should return None on error
|
||||
token = await token_broker.get_nextcloud_token("user1")
|
||||
assert token is None
|
||||
|
||||
# Cache should be invalidated
|
||||
assert await token_broker.cache.get("user1") is None
|
||||
|
||||
async def test_concurrent_cache_access(self, token_broker):
|
||||
"""Test concurrent access to token cache."""
|
||||
# Pre-populate cache
|
||||
await token_broker.cache.set("user1", "token1", expires_in=300)
|
||||
await token_broker.cache.set("user2", "token2", expires_in=300)
|
||||
|
||||
# Concurrent reads
|
||||
results = await asyncio.gather(
|
||||
token_broker.cache.get("user1"),
|
||||
token_broker.cache.get("user2"),
|
||||
token_broker.cache.get("user1"),
|
||||
token_broker.cache.get("user2"),
|
||||
)
|
||||
|
||||
assert results == ["token1", "token2", "token1", "token2"]
|
||||
Reference in New Issue
Block a user