feat: Implement custom PCA to remove sklearn dependency
- Add custom PCA implementation using numpy eigendecomposition - Replace sklearn.decomposition.PCA with custom implementation - Maintains same API (fit, transform, fit_transform) - Supports explained_variance_ratio_ for variance analysis - Removes scikit-learn dependency from project - Add type hints and assertion for type safety
This commit is contained in:
@@ -0,0 +1,581 @@
|
||||
"""Vector visualization routes for testing search algorithms.
|
||||
|
||||
Provides a web UI for users to test different search algorithms on their own
|
||||
indexed documents and visualize results in 2D space using PCA.
|
||||
|
||||
All processing happens server-side following ADR-012:
|
||||
- Search execution via shared search/algorithms.py
|
||||
- PCA dimensionality reduction (768-dim → 2D)
|
||||
- Only 2D coordinates + metadata sent to client
|
||||
- Bandwidth-efficient (2 floats per doc vs 768)
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
from starlette.authentication import requires
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse, JSONResponse
|
||||
|
||||
from nextcloud_mcp_server.config import get_settings
|
||||
from nextcloud_mcp_server.search import (
|
||||
FuzzySearchAlgorithm,
|
||||
HybridSearchAlgorithm,
|
||||
KeywordSearchAlgorithm,
|
||||
SemanticSearchAlgorithm,
|
||||
)
|
||||
from nextcloud_mcp_server.vector.pca import PCA
|
||||
from nextcloud_mcp_server.vector.qdrant_client import get_qdrant_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@requires("authenticated", redirect="oauth_login")
|
||||
async def vector_visualization_html(request: Request) -> HTMLResponse:
|
||||
"""Vector visualization page with search controls and interactive plot.
|
||||
|
||||
Provides UI for testing search algorithms with real-time visualization.
|
||||
Requires vector sync to be enabled.
|
||||
|
||||
Args:
|
||||
request: Starlette request object
|
||||
|
||||
Returns:
|
||||
HTML page with search interface
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
if not settings.vector_sync_enabled:
|
||||
return HTMLResponse(
|
||||
"""
|
||||
<div>
|
||||
<h2>Vector Visualization</h2>
|
||||
<div style="padding: 20px; background: #fff3cd; border: 1px solid #ffc107; border-radius: 4px;">
|
||||
Vector sync is not enabled. Set VECTOR_SYNC_ENABLED=true to use this feature.
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
)
|
||||
|
||||
# Get user info from session
|
||||
user_info = request.session.get("user_info", {})
|
||||
username = user_info.get("preferred_username", "unknown")
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Vector Visualization - Nextcloud MCP</title>
|
||||
<script src="https://cdn.plot.ly/plotly-2.26.0.min.js"></script>
|
||||
<script src="https://unpkg.com/htmx.org@1.9.10"></script>
|
||||
<script src="https://unpkg.com/alpinejs@3.13.3/dist/cdn.min.js" defer></script>
|
||||
<style>
|
||||
body {{
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background: #f5f5f5;
|
||||
}}
|
||||
.container {{
|
||||
max-width: 1400px;
|
||||
margin: 0 auto;
|
||||
}}
|
||||
.card {{
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
margin-bottom: 20px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
}}
|
||||
.controls {{
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: 20px;
|
||||
margin-bottom: 20px;
|
||||
}}
|
||||
.control-group {{
|
||||
margin-bottom: 15px;
|
||||
}}
|
||||
label {{
|
||||
display: block;
|
||||
margin-bottom: 5px;
|
||||
font-weight: 500;
|
||||
color: #333;
|
||||
}}
|
||||
input[type="text"], select {{
|
||||
width: 100%;
|
||||
padding: 8px 12px;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
}}
|
||||
input[type="range"] {{
|
||||
width: 100%;
|
||||
}}
|
||||
.weight-display {{
|
||||
display: inline-block;
|
||||
min-width: 40px;
|
||||
text-align: right;
|
||||
color: #666;
|
||||
}}
|
||||
.btn {{
|
||||
background: #0066cc;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 10px 20px;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
}}
|
||||
.btn:hover {{
|
||||
background: #0052a3;
|
||||
}}
|
||||
#plot {{
|
||||
width: 100%;
|
||||
height: 600px;
|
||||
}}
|
||||
.loading {{
|
||||
text-align: center;
|
||||
padding: 40px;
|
||||
color: #666;
|
||||
}}
|
||||
.weight-controls {{
|
||||
display: none;
|
||||
}}
|
||||
.weight-controls.active {{
|
||||
display: block;
|
||||
}}
|
||||
.info-box {{
|
||||
background: #e3f2fd;
|
||||
border-left: 4px solid #2196f3;
|
||||
padding: 12px;
|
||||
margin-bottom: 20px;
|
||||
font-size: 14px;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container" x-data="vizApp()">
|
||||
<div class="card">
|
||||
<h1>Vector Visualization</h1>
|
||||
<div class="info-box">
|
||||
Testing search algorithms on your indexed documents. User: <strong>{username}</strong>
|
||||
</div>
|
||||
|
||||
<form @submit.prevent="executeSearch">
|
||||
<div class="controls">
|
||||
<div>
|
||||
<div class="control-group">
|
||||
<label>Search Query</label>
|
||||
<input type="text" x-model="query" placeholder="Enter search query..." />
|
||||
</div>
|
||||
|
||||
<div class="control-group">
|
||||
<label>Search Algorithm</label>
|
||||
<select x-model="algorithm" @change="updateWeightControls">
|
||||
<option value="semantic">Semantic (Vector Similarity)</option>
|
||||
<option value="keyword">Keyword (Token Matching)</option>
|
||||
<option value="fuzzy">Fuzzy (Character Overlap)</option>
|
||||
<option value="hybrid" selected>Hybrid (RRF Fusion)</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div class="control-group weight-controls" :class="{{ active: algorithm === 'hybrid' }}">
|
||||
<label>Hybrid Weights</label>
|
||||
<div style="margin-bottom: 8px;">
|
||||
<label style="display: inline-block; width: 100px;">Semantic:</label>
|
||||
<input type="range" x-model.number="semanticWeight" min="0" max="1" step="0.1" style="width: 200px; display: inline-block;">
|
||||
<span class="weight-display" x-text="semanticWeight.toFixed(1)"></span>
|
||||
</div>
|
||||
<div style="margin-bottom: 8px;">
|
||||
<label style="display: inline-block; width: 100px;">Keyword:</label>
|
||||
<input type="range" x-model.number="keywordWeight" min="0" max="1" step="0.1" style="width: 200px; display: inline-block;">
|
||||
<span class="weight-display" x-text="keywordWeight.toFixed(1)"></span>
|
||||
</div>
|
||||
<div>
|
||||
<label style="display: inline-block; width: 100px;">Fuzzy:</label>
|
||||
<input type="range" x-model.number="fuzzyWeight" min="0" max="1" step="0.1" style="width: 200px; display: inline-block;">
|
||||
<span class="weight-display" x-text="fuzzyWeight.toFixed(1)"></span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div class="control-group">
|
||||
<label>Result Limit</label>
|
||||
<input type="number" x-model.number="limit" min="1" max="100" value="50" />
|
||||
</div>
|
||||
|
||||
<div class="control-group">
|
||||
<label>Score Threshold (Semantic/Hybrid)</label>
|
||||
<input type="number" x-model.number="scoreThreshold" min="0" max="1" step="0.1" value="0.7" />
|
||||
</div>
|
||||
|
||||
<div class="control-group">
|
||||
<button type="submit" class="btn">Search & Visualize</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<div class="card">
|
||||
<div x-show="loading" class="loading">
|
||||
Executing search and computing PCA projection...
|
||||
</div>
|
||||
<div id="plot" x-show="!loading"></div>
|
||||
</div>
|
||||
|
||||
<div class="card" x-show="results.length > 0">
|
||||
<h2>Search Results (<span x-text="results.length"></span>)</h2>
|
||||
<template x-for="result in results" :key="result.id">
|
||||
<div style="padding: 12px; border-bottom: 1px solid #eee;">
|
||||
<div style="font-weight: 500; color: #0066cc;" x-text="result.title"></div>
|
||||
<div style="font-size: 14px; color: #666; margin-top: 4px;" x-text="result.excerpt"></div>
|
||||
<div style="font-size: 12px; color: #999; margin-top: 4px;">
|
||||
Score: <span x-text="result.score.toFixed(3)"></span> |
|
||||
Type: <span x-text="result.doc_type"></span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
function vizApp() {{
|
||||
return {{
|
||||
query: '',
|
||||
algorithm: 'hybrid',
|
||||
limit: 50,
|
||||
scoreThreshold: 0.7,
|
||||
semanticWeight: 0.5,
|
||||
keywordWeight: 0.3,
|
||||
fuzzyWeight: 0.2,
|
||||
loading: false,
|
||||
results: [],
|
||||
|
||||
updateWeightControls() {{
|
||||
// Update weight controls visibility based on algorithm
|
||||
}},
|
||||
|
||||
async executeSearch() {{
|
||||
this.loading = true;
|
||||
this.results = [];
|
||||
|
||||
try {{
|
||||
const params = new URLSearchParams({{
|
||||
query: this.query,
|
||||
algorithm: this.algorithm,
|
||||
limit: this.limit,
|
||||
score_threshold: this.scoreThreshold,
|
||||
semantic_weight: this.semanticWeight,
|
||||
keyword_weight: this.keywordWeight,
|
||||
fuzzy_weight: this.fuzzyWeight,
|
||||
}});
|
||||
|
||||
const response = await fetch(`/app/vector-viz/search?${{params}}`);
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success) {{
|
||||
this.results = data.results;
|
||||
this.renderPlot(data.coordinates_2d, data.results);
|
||||
}} else {{
|
||||
alert('Search failed: ' + data.error);
|
||||
}}
|
||||
}} catch (error) {{
|
||||
alert('Error: ' + error.message);
|
||||
}} finally {{
|
||||
this.loading = false;
|
||||
}}
|
||||
}},
|
||||
|
||||
renderPlot(coordinates, results) {{
|
||||
const trace = {{
|
||||
x: coordinates.map(c => c[0]),
|
||||
y: coordinates.map(c => c[1]),
|
||||
mode: 'markers',
|
||||
type: 'scatter',
|
||||
text: results.map(r => `${{r.title}}<br>Score: ${{r.score.toFixed(3)}}`),
|
||||
marker: {{
|
||||
size: 8,
|
||||
color: results.map(r => r.score),
|
||||
colorscale: 'Viridis',
|
||||
showscale: true,
|
||||
colorbar: {{ title: 'Score' }}
|
||||
}}
|
||||
}};
|
||||
|
||||
const layout = {{
|
||||
title: `Vector Space (PCA 2D) - ${{results.length}} results`,
|
||||
xaxis: {{ title: 'PC1' }},
|
||||
yaxis: {{ title: 'PC2' }},
|
||||
hovermode: 'closest',
|
||||
height: 600
|
||||
}};
|
||||
|
||||
Plotly.newPlot('plot', [trace], layout);
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return HTMLResponse(content=html_content)
|
||||
|
||||
|
||||
@requires("authenticated", redirect="oauth_login")
|
||||
async def vector_visualization_search(request: Request) -> JSONResponse:
|
||||
"""Execute server-side search and return 2D coordinates + results.
|
||||
|
||||
All processing happens server-side:
|
||||
1. Execute search via shared algorithm module
|
||||
2. Fetch matching vectors from Qdrant
|
||||
3. Apply PCA reduction (768-dim → 2D)
|
||||
4. Return coordinates + metadata only
|
||||
|
||||
Args:
|
||||
request: Starlette request with query parameters
|
||||
|
||||
Returns:
|
||||
JSON response with coordinates_2d and results
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
if not settings.vector_sync_enabled:
|
||||
return JSONResponse(
|
||||
{"success": False, "error": "Vector sync not enabled"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Get user info
|
||||
user_info = request.session.get("user_info", {})
|
||||
username = user_info.get("preferred_username")
|
||||
|
||||
if not username:
|
||||
return JSONResponse(
|
||||
{"success": False, "error": "User not authenticated"},
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
# Parse query parameters
|
||||
query = request.query_params.get("query", "")
|
||||
algorithm = request.query_params.get("algorithm", "hybrid")
|
||||
limit = int(request.query_params.get("limit", "50"))
|
||||
score_threshold = float(request.query_params.get("score_threshold", "0.7"))
|
||||
semantic_weight = float(request.query_params.get("semantic_weight", "0.5"))
|
||||
keyword_weight = float(request.query_params.get("keyword_weight", "0.3"))
|
||||
fuzzy_weight = float(request.query_params.get("fuzzy_weight", "0.2"))
|
||||
|
||||
logger.info(
|
||||
f"Viz search: user={username}, query='{query}', "
|
||||
f"algorithm={algorithm}, limit={limit}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Get authenticated HTTP client from session
|
||||
# In BasicAuth mode: uses username/password from session
|
||||
# In OAuth mode: uses access token from session
|
||||
from nextcloud_mcp_server.auth.userinfo_routes import (
|
||||
_get_authenticated_client_for_userinfo,
|
||||
)
|
||||
from nextcloud_mcp_server.client.notes import NotesClient
|
||||
|
||||
async with await _get_authenticated_client_for_userinfo(request) as http_client:
|
||||
# Create NotesClient directly with authenticated HTTP client
|
||||
notes_client = NotesClient(http_client, username)
|
||||
|
||||
# Wrap in a minimal client object for search algorithms
|
||||
# This conforms to NextcloudClientProtocol but only implements notes
|
||||
class MinimalNextcloudClient:
|
||||
def __init__(self, notes_client, username):
|
||||
self._notes = notes_client
|
||||
self.username = username
|
||||
|
||||
@property
|
||||
def notes(self):
|
||||
return self._notes
|
||||
|
||||
@property
|
||||
def webdav(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def calendar(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def contacts(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def deck(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def cookbook(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def tables(self):
|
||||
return None
|
||||
|
||||
nextcloud_client = MinimalNextcloudClient(notes_client, username)
|
||||
|
||||
# Create search algorithm
|
||||
if algorithm == "semantic":
|
||||
search_algo = SemanticSearchAlgorithm(score_threshold=score_threshold)
|
||||
elif algorithm == "keyword":
|
||||
search_algo = KeywordSearchAlgorithm()
|
||||
elif algorithm == "fuzzy":
|
||||
search_algo = FuzzySearchAlgorithm()
|
||||
elif algorithm == "hybrid":
|
||||
search_algo = HybridSearchAlgorithm(
|
||||
semantic_weight=semantic_weight,
|
||||
keyword_weight=keyword_weight,
|
||||
fuzzy_weight=fuzzy_weight,
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
{"success": False, "error": f"Unknown algorithm: {algorithm}"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Execute search
|
||||
search_results = await search_algo.search(
|
||||
query=query,
|
||||
user_id=username,
|
||||
limit=limit,
|
||||
doc_type="note",
|
||||
nextcloud_client=nextcloud_client,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
if not search_results:
|
||||
return JSONResponse(
|
||||
{
|
||||
"success": True,
|
||||
"results": [],
|
||||
"coordinates_2d": [],
|
||||
"message": "No results found",
|
||||
}
|
||||
)
|
||||
|
||||
# Fetch vectors for matching results from Qdrant
|
||||
qdrant_client = await get_qdrant_client()
|
||||
doc_ids = [r.id for r in search_results]
|
||||
|
||||
# Retrieve vectors for the matching documents
|
||||
from qdrant_client.models import FieldCondition, Filter, MatchAny
|
||||
|
||||
points_response = await qdrant_client.scroll(
|
||||
collection_name=settings.get_collection_name(),
|
||||
scroll_filter=Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="doc_id",
|
||||
match=MatchAny(any=[str(doc_id) for doc_id in doc_ids]),
|
||||
),
|
||||
FieldCondition(
|
||||
key="user_id",
|
||||
match={"value": username},
|
||||
),
|
||||
]
|
||||
),
|
||||
limit=len(doc_ids) * 2, # Account for multiple chunks per doc
|
||||
with_vectors=True,
|
||||
with_payload=False,
|
||||
)
|
||||
|
||||
points = points_response[0]
|
||||
|
||||
if not points:
|
||||
return JSONResponse(
|
||||
{
|
||||
"success": True,
|
||||
"results": [],
|
||||
"coordinates_2d": [],
|
||||
"message": "No vectors found for results",
|
||||
}
|
||||
)
|
||||
|
||||
# Extract vectors
|
||||
vectors = np.array([p.vector for p in points if p.vector is not None])
|
||||
|
||||
if len(vectors) < 2:
|
||||
# Not enough points for PCA
|
||||
return JSONResponse(
|
||||
{
|
||||
"success": True,
|
||||
"results": [
|
||||
{
|
||||
"id": r.id,
|
||||
"doc_type": r.doc_type,
|
||||
"title": r.title,
|
||||
"excerpt": r.excerpt,
|
||||
"score": r.score,
|
||||
}
|
||||
for r in search_results
|
||||
],
|
||||
"coordinates_2d": [[0, 0]] * len(search_results),
|
||||
"message": "Not enough vectors for PCA",
|
||||
}
|
||||
)
|
||||
|
||||
# Apply PCA dimensionality reduction (768-dim → 2D)
|
||||
pca = PCA(n_components=2)
|
||||
coords_2d = pca.fit_transform(vectors)
|
||||
|
||||
# After fit, these attributes are guaranteed to be set
|
||||
assert pca.explained_variance_ratio_ is not None
|
||||
|
||||
logger.info(
|
||||
f"PCA explained variance: PC1={pca.explained_variance_ratio_[0]:.3f}, "
|
||||
f"PC2={pca.explained_variance_ratio_[1]:.3f}"
|
||||
)
|
||||
|
||||
# Map results to coordinates (use first chunk per document)
|
||||
result_coords = []
|
||||
seen_doc_ids = set()
|
||||
|
||||
for point, coord in zip(points, coords_2d):
|
||||
if point.payload:
|
||||
doc_id = int(point.payload.get("doc_id", 0))
|
||||
if doc_id not in seen_doc_ids and doc_id in doc_ids:
|
||||
seen_doc_ids.add(doc_id)
|
||||
result_coords.append(coord.tolist())
|
||||
|
||||
# Build response
|
||||
response_results = [
|
||||
{
|
||||
"id": r.id,
|
||||
"doc_type": r.doc_type,
|
||||
"title": r.title,
|
||||
"excerpt": r.excerpt,
|
||||
"score": r.score,
|
||||
}
|
||||
for r in search_results
|
||||
]
|
||||
|
||||
return JSONResponse(
|
||||
{
|
||||
"success": True,
|
||||
"results": response_results,
|
||||
"coordinates_2d": result_coords[: len(search_results)],
|
||||
"pca_variance": {
|
||||
"pc1": float(pca.explained_variance_ratio_[0]),
|
||||
"pc2": float(pca.explained_variance_ratio_[1]),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Viz search error: {e}", exc_info=True)
|
||||
return JSONResponse(
|
||||
{"success": False, "error": str(e)},
|
||||
status_code=500,
|
||||
)
|
||||
@@ -0,0 +1,140 @@
|
||||
"""Custom PCA implementation for dimensionality reduction.
|
||||
|
||||
Implements Principal Component Analysis without scikit-learn dependency.
|
||||
Used for reducing high-dimensional embeddings (768-dim) to 2D for visualization.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PCA:
|
||||
"""Principal Component Analysis for dimensionality reduction.
|
||||
|
||||
Simple implementation that finds principal components via eigendecomposition
|
||||
of the covariance matrix. Suitable for small-to-medium datasets.
|
||||
|
||||
Attributes:
|
||||
n_components: Number of principal components to keep
|
||||
mean_: Mean of training data (set during fit)
|
||||
components_: Principal components (eigenvectors)
|
||||
explained_variance_: Variance explained by each component
|
||||
explained_variance_ratio_: Fraction of total variance explained
|
||||
"""
|
||||
|
||||
def __init__(self, n_components: int = 2):
|
||||
"""Initialize PCA.
|
||||
|
||||
Args:
|
||||
n_components: Number of components to keep (default: 2)
|
||||
"""
|
||||
if n_components < 1:
|
||||
raise ValueError(f"n_components must be >= 1, got {n_components}")
|
||||
|
||||
self.n_components = n_components
|
||||
self.mean_: np.ndarray | None = None
|
||||
self.components_: np.ndarray | None = None
|
||||
self.explained_variance_: np.ndarray | None = None
|
||||
self.explained_variance_ratio_: np.ndarray | None = None
|
||||
|
||||
def fit(self, X: np.ndarray) -> "PCA":
|
||||
"""Fit PCA model to data.
|
||||
|
||||
Args:
|
||||
X: Training data of shape (n_samples, n_features)
|
||||
|
||||
Returns:
|
||||
self (for method chaining)
|
||||
|
||||
Raises:
|
||||
ValueError: If X has fewer features than n_components
|
||||
"""
|
||||
X = np.asarray(X)
|
||||
|
||||
if X.ndim != 2:
|
||||
raise ValueError(f"X must be 2D array, got shape {X.shape}")
|
||||
|
||||
n_samples, n_features = X.shape
|
||||
|
||||
if n_features < self.n_components:
|
||||
raise ValueError(
|
||||
f"n_components={self.n_components} > n_features={n_features}"
|
||||
)
|
||||
|
||||
# Center data
|
||||
self.mean_ = np.mean(X, axis=0)
|
||||
X_centered = X - self.mean_
|
||||
|
||||
# Compute covariance matrix
|
||||
# Use (X^T X) / (n-1) for numerical stability with high-dim data
|
||||
cov = np.cov(X_centered.T)
|
||||
|
||||
# Eigendecomposition
|
||||
eigenvalues, eigenvectors = np.linalg.eigh(cov)
|
||||
|
||||
# Sort by eigenvalue (descending)
|
||||
idx = np.argsort(eigenvalues)[::-1]
|
||||
eigenvalues = eigenvalues[idx]
|
||||
eigenvectors = eigenvectors[:, idx]
|
||||
|
||||
# Keep top n_components
|
||||
self.components_ = eigenvectors[:, : self.n_components].T
|
||||
self.explained_variance_ = eigenvalues[: self.n_components]
|
||||
|
||||
# Calculate explained variance ratio
|
||||
total_variance = np.sum(eigenvalues)
|
||||
if total_variance > 0:
|
||||
self.explained_variance_ratio_ = self.explained_variance_ / total_variance
|
||||
else:
|
||||
self.explained_variance_ratio_ = np.zeros(self.n_components)
|
||||
|
||||
logger.debug(
|
||||
f"PCA fit: {n_samples} samples, {n_features} features → "
|
||||
f"{self.n_components} components, "
|
||||
f"explained variance: {self.explained_variance_ratio_}"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def transform(self, X: np.ndarray) -> np.ndarray:
|
||||
"""Transform data to principal component space.
|
||||
|
||||
Args:
|
||||
X: Data to transform of shape (n_samples, n_features)
|
||||
|
||||
Returns:
|
||||
Transformed data of shape (n_samples, n_components)
|
||||
|
||||
Raises:
|
||||
ValueError: If PCA not fitted yet
|
||||
"""
|
||||
if self.mean_ is None or self.components_ is None:
|
||||
raise ValueError("PCA not fitted yet. Call fit() first.")
|
||||
|
||||
X = np.asarray(X)
|
||||
|
||||
if X.ndim != 2:
|
||||
raise ValueError(f"X must be 2D array, got shape {X.shape}")
|
||||
|
||||
# Center using training mean
|
||||
X_centered = X - self.mean_
|
||||
|
||||
# Project onto principal components
|
||||
X_transformed = np.dot(X_centered, self.components_.T)
|
||||
|
||||
return X_transformed
|
||||
|
||||
def fit_transform(self, X: np.ndarray) -> np.ndarray:
|
||||
"""Fit PCA model and transform data in one step.
|
||||
|
||||
Args:
|
||||
X: Training data of shape (n_samples, n_features)
|
||||
|
||||
Returns:
|
||||
Transformed data of shape (n_samples, n_components)
|
||||
"""
|
||||
self.fit(X)
|
||||
return self.transform(X)
|
||||
Reference in New Issue
Block a user