Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,11 @@ pub async fn generate_embeddings(
const CONCURRENCY_LIMIT: usize = 8; // Number of concurrent requests
const TOKEN_LIMIT: usize = 8000; // Keep a buffer below the 8192 limit

let results = stream::iter(documents.iter().enumerate())
let results = stream::iter(documents.iter().enumerate().map(|(i, d)| (i, d.clone())).collect::<Vec<_>>())
.map(|(index, doc)| {
// Clone client, model, doc, and Arc<BPE> for the async block
let client = client.clone();
let model = model.to_string();
let doc = doc.clone();
let bpe = Arc::clone(&bpe); // Clone the Arc pointer

async move {
Expand Down
108 changes: 76 additions & 32 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use crate::{
doc_loader::Document,
embeddings::{generate_embeddings, CachedDocumentEmbedding, OPENAI_CLIENT},
error::ServerError,
server::RustDocsServer, // Import the updated RustDocsServer
};
use async_openai::{Client as OpenAIClient, config::OpenAIConfig};
use bincode::config;
Expand Down Expand Up @@ -38,10 +37,12 @@ use xdg::BaseDirectories;
#[command(author, version, about, long_about = None)]
struct Cli {
/// The package ID specification (e.g., "serde@^1.0", "tokio").
/// If not provided, server runs in "any crate mode" allowing queries for any crate.
#[arg()] // Positional argument
package_spec: String,
package_spec: Option<String>,

/// Optional features to enable for the crate when generating documentation.
/// Only used in single-crate mode.
#[arg(short = 'F', long, value_delimiter = ',', num_args = 0..)] // Allow multiple comma-separated values
features: Option<Vec<String>>,
}
Expand All @@ -67,9 +68,39 @@ async fn main() -> Result<(), ServerError> {

// --- Parse CLI Arguments ---
let cli = Cli::parse();
let specid_str = cli.package_spec.trim().to_string(); // Trim whitespace
let features = cli.features.map(|f| {
f.into_iter().map(|s| s.trim().to_string()).collect() // Trim each feature

// Initialize OpenAI Client early (needed for both modes)
let openai_client = if let Ok(api_base) = env::var("OPENAI_API_BASE") {
let config = OpenAIConfig::new().with_api_base(api_base);
OpenAIClient::with_config(config)
} else {
OpenAIClient::new()
};
OPENAI_CLIENT
.set(openai_client.clone())
.expect("Failed to set OpenAI client");

// Determine which mode to run in
match cli.package_spec {
Some(package_spec) => {
// Single-crate mode
run_single_crate_mode(package_spec, cli.features).await
}
None => {
// Any-crate mode
run_any_crate_mode().await
}
}
}

/// Run the server in single-crate mode with pre-loaded documentation
async fn run_single_crate_mode(
package_spec: String,
features: Option<Vec<String>>,
) -> Result<(), ServerError> {
let specid_str = package_spec.trim().to_string();
let features = features.map(|f| {
f.into_iter().map(|s| s.trim().to_string()).collect()
});

// Parse the specid string
Expand All @@ -87,23 +118,19 @@ async fn main() -> Result<(), ServerError> {
.unwrap_or_else(|| "*".to_string());

eprintln!(
"Target Spec: {}, Parsed Name: {}, Version Req: {}, Features: {:?}",
"Single-crate mode - Target Spec: {}, Parsed Name: {}, Version Req: {}, Features: {:?}",
specid_str, crate_name, crate_version_req, features
);

// --- Determine Paths (incorporating features) ---

// Sanitize the version requirement string
let sanitized_version_req = crate_version_req
.replace(|c: char| !c.is_alphanumeric() && c != '.' && c != '-', "_");

// Generate a stable hash for the features to use in the path
let features_hash = hash_features(&features);

// Construct the relative path component including features hash
let embeddings_relative_path = PathBuf::from(&crate_name)
.join(&sanitized_version_req)
.join(&features_hash) // Add features hash as a directory level
.join(&features_hash)
.join("embeddings.bin");

#[cfg(not(target_os = "windows"))]
Expand All @@ -121,7 +148,6 @@ async fn main() -> Result<(), ServerError> {
ServerError::Config("Could not determine cache directory on Windows".to_string())
})?;
let app_cache_dir = cache_dir.join("rustdocs-mcp-server");
// Ensure the base app cache directory exists
fs::create_dir_all(&app_cache_dir).map_err(ServerError::Io)?;
app_cache_dir.join(embeddings_relative_path)
};
Expand Down Expand Up @@ -181,16 +207,7 @@ async fn main() -> Result<(), ServerError> {
let mut generation_cost: Option<f64> = None;
let mut documents_for_server: Vec<Document> = loaded_documents_from_cache.unwrap_or_default();

// --- Initialize OpenAI Client (needed for question embedding even if cache hit) ---
let openai_client = if let Ok(api_base) = env::var("OPENAI_API_BASE") {
let config = OpenAIConfig::new().with_api_base(api_base);
OpenAIClient::with_config(config)
} else {
OpenAIClient::new()
};
OPENAI_CLIENT
.set(openai_client.clone()) // Clone the client for the OnceCell
.expect("Failed to set OpenAI client");
let openai_client = OPENAI_CLIENT.get().unwrap();

let final_embeddings = match loaded_embeddings {
Some(embeddings) => {
Expand All @@ -207,9 +224,8 @@ async fn main() -> Result<(), ServerError> {
"Loading documents for crate: {} (Version Req: {}, Features: {:?})",
crate_name, crate_version_req, features
);
// Pass features to load_documents
let loaded_documents =
doc_loader::load_documents(&crate_name, &crate_version_req, features.as_ref())?; // Pass features here
doc_loader::load_documents(&crate_name, &crate_version_req, features.as_ref())?;
eprintln!("Loaded {} documents.", loaded_documents.len());
documents_for_server = loaded_documents.clone();

Expand Down Expand Up @@ -312,29 +328,57 @@ async fn main() -> Result<(), ServerError> {
)
};

// Create the service instance using the updated ::new()
let service = RustDocsServer::new(
crate_name.clone(), // Pass crate_name directly
// Create the service instance for single-crate mode
let service = server::RustDocsSingleCrateServer::new(
crate_name.clone(),
documents_for_server,
final_embeddings,
startup_message,
)?;

// --- Use standard stdio transport and ServiceExt ---
eprintln!("Rust Docs MCP server starting via stdio...");
eprintln!("Rust Docs MCP server starting via stdio (single-crate mode)...");

// Serve the server using the ServiceExt trait and standard stdio transport
let server_handle = service.serve(stdio()).await.map_err(|e| {
eprintln!("Failed to start server: {:?}", e);
ServerError::McpRuntime(e.to_string()) // Use the new McpRuntime variant
ServerError::McpRuntime(e.to_string())
})?;

eprintln!("{} Docs MCP server running...", &crate_name);

// Wait for the server to complete (e.g., stdin closed)
server_handle.waiting().await.map_err(|e| {
eprintln!("Server encountered an error while running: {:?}", e);
ServerError::McpRuntime(e.to_string()) // Use the new McpRuntime variant
ServerError::McpRuntime(e.to_string())
})?;

eprintln!("Rust Docs MCP server stopped.");
Ok(())
}

/// Run the server in any-crate mode where documentation is loaded on-demand
async fn run_any_crate_mode() -> Result<(), ServerError> {
eprintln!("Any-crate mode - Server will load documentation on demand for any requested crate");

// Verify OpenAI API key is available
let _openai_api_key = env::var("OPENAI_API_KEY")
.map_err(|_| ServerError::MissingEnvVar("OPENAI_API_KEY".to_string()))?;

// Create the service instance for any-crate mode
let service = server::RustDocsAnyCrateServer::new()?;

// --- Use standard stdio transport and ServiceExt ---
eprintln!("Rust Docs MCP server starting via stdio (any-crate mode)...");

let server_handle = service.serve(stdio()).await.map_err(|e| {
eprintln!("Failed to start server: {:?}", e);
ServerError::McpRuntime(e.to_string())
})?;

eprintln!("Rust Docs MCP server running (any-crate mode)...");

server_handle.waiting().await.map_err(|e| {
eprintln!("Server encountered an error while running: {:?}", e);
ServerError::McpRuntime(e.to_string())
})?;

eprintln!("Rust Docs MCP server stopped.");
Expand Down
Loading