//! REST API for RAG System (Phase 8)
//!
//! Provides comprehensive HTTP endpoints for all Phase 7 features including:
//! - Query processing with optional streaming
//! - Hybrid search configuration
//! - Conversation management
//! - Batch processing
//! - Tool execution
//! - Cache management
//!
//! # Features
//!
//! - Async request handling with Axum
//! - JSON request/response validation
//! - Error handling with proper HTTP status codes
//! - OpenAPI/Swagger compatible types
//! - CORS support for cross-origin requests
use std::sync::Arc;

use axum::{
    extract::{Json, Path, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::{get, post},
    Router,
};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

use crate::{
    agent::{AgentResponse, RagAgent},
    batch_processing::{BatchAgent, BatchJob, BatchQuery},
    caching::ResponseCache,
    conversations::ConversationAgent,
    error::{RagError, Result},
    query_optimization::QueryOptimizer,
    retrieval::SearchResult,
};
/// API server state containing core RAG components.
///
/// Cloned by Axum once per request; every field is behind an `Arc`, so
/// cloning only bumps reference counts.
#[derive(Clone)]
pub struct ApiState {
    // Core RAG agent that answers queries.
    pub agent: Arc<RagAgent>,
    // Response cache for repeated queries.
    pub cache: Arc<ResponseCache>,
    // Query optimizer; rewrites queries, optionally with conversation context.
    pub optimizer: Arc<QueryOptimizer>,
    // Multi-turn conversation state; async RwLock because handlers may
    // read concurrently but mutate on new turns.
    pub conversation: Arc<RwLock<ConversationAgent>>,
    // Batch agent for parallel query execution.
    pub batch_agent: Arc<BatchAgent>,
}
/// Query request with optional context
///
/// Deserialized from the `POST /query` JSON body. `use_hybrid_search` and
/// `num_results` fall back to serde defaults when omitted.
/// NOTE(review): neither `use_hybrid_search` nor `num_results` is consumed
/// by `handle_query` yet — confirm intended wiring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRequest {
    /// The user's question
    pub query: String,
    /// Optional conversation context for follow-up questions
    pub conversation_context: Option<String>,
    /// Whether to use hybrid search (default: true)
    #[serde(default = "default_hybrid_search")]
    pub use_hybrid_search: bool,
    /// Number of results to retrieve (default: 5)
    #[serde(default = "default_num_results")]
    pub num_results: usize,
}
/// Serde default for `QueryRequest::use_hybrid_search` (hybrid search on).
fn default_hybrid_search() -> bool {
    true
}
/// Serde default for `QueryRequest::num_results`.
fn default_num_results() -> usize {
    5
}
/// Query response with metadata
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct QueryResponse {
|
||
|
|
/// The generated answer
|
||
|
|
pub answer: String,
|
||
|
|
/// Retrieved source documents
|
||
|
|
pub sources: Vec<SearchResult>,
|
||
|
|
/// Confidence score (0.0-1.0)
|
||
|
|
pub confidence: f32,
|
||
|
|
/// Processing context
|
||
|
|
pub context: String,
|
||
|
|
/// Whether result was from cache
|
||
|
|
#[serde(skip_serializing_if = "std::option::Option::is_none")]
|
||
|
|
pub from_cache: Option<bool>,
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Batch query for parallel processing
///
/// Deserialized from the `POST /batch` JSON body; concurrency and timeout
/// fall back to serde defaults when omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryRequest {
    /// Individual queries in the batch
    pub queries: Vec<String>,
    /// Maximum concurrent queries (default: 5)
    #[serde(default = "default_max_concurrent")]
    pub max_concurrent: usize,
    /// Timeout per query in seconds (default: 30)
    #[serde(default = "default_timeout_secs")]
    pub timeout_secs: u64,
}
/// Serde default for `BatchQueryRequest::max_concurrent`.
fn default_max_concurrent() -> usize {
    5
}
/// Serde default for `BatchQueryRequest::timeout_secs` (seconds per query).
fn default_timeout_secs() -> u64 {
    30
}
/// Batch processing response
///
/// Returned by `POST /batch`; `job_id` can be used with
/// `GET /batch/:job_id` for status polling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchResponse {
    /// Job ID for tracking
    pub job_id: String,
    /// Individual query results
    pub results: Vec<QueryResponse>,
    /// Batch statistics
    pub stats: BatchStats,
}
/// Batch processing statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchStats {
    /// Total queries in batch
    pub total_queries: usize,
    /// Successfully processed queries
    pub successful_queries: usize,
    /// Failed queries
    pub failed_queries: usize,
    /// Success rate as percentage
    // NOTE(review): `handle_batch` and the unit test populate this as a
    // 0.0-1.0 fraction, not a 0-100 percentage — confirm intended scale.
    pub success_rate: f32,
    /// Total processing time in milliseconds
    pub total_duration_ms: u64,
}
/// Conversation turn request
///
/// Deserialized from the `POST /conversation` JSON body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationRequest {
    /// The user's message
    pub message: String,
    /// Conversation ID for tracking
    pub conversation_id: String,
}
/// Conversation turn response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationResponse {
    /// Agent's response
    pub response: String,
    /// Updated conversation context
    pub context: String,
    /// Follow-up suggestions
    // Omitted from the JSON body entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub followup_suggestions: Option<Vec<String>>,
}
/// Cache statistics
///
/// Returned by `GET /cache/stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStatsResponse {
    /// Total items in cache
    pub items_in_cache: usize,
    /// Cache hit count
    pub hits: u64,
    /// Cache miss count
    pub misses: u64,
    /// Hit rate as percentage
    pub hit_rate: f32,
}
/// Error response format
///
/// JSON body emitted by `error_response` for every `RagError`; the `status`
/// field duplicates the HTTP status code for clients that drop headers.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    /// Error message
    pub error: String,
    /// Error code
    pub code: String,
    /// HTTP status code
    pub status: u16,
}
/// Health check response
///
/// Returned by `GET /health`.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Service status
    pub status: String,
    /// Version information
    pub version: String,
    /// Component health
    pub components: ComponentHealth,
}
/// Component health status
///
/// Per-component status strings embedded in `HealthResponse`.
#[derive(Debug, Serialize)]
pub struct ComponentHealth {
    /// RAG agent status
    pub agent: String,
    /// Cache status
    pub cache: String,
    /// Database connection status
    pub database: String,
}
/// Create the REST API router
///
/// Builds the full route table and attaches the shared `ApiState`.
/// NOTE(review): the function body contains no `.await`, so `async` here is
/// only for signature compatibility with callers that await it.
/// NOTE(review): `:param` capture syntax is the axum <=0.7 convention
/// (0.8 switched to `{param}`) — confirm against the axum version in use.
pub async fn create_router(state: ApiState) -> Router {
    Router::new()
        // Health and info endpoints
        .route("/health", get(health_check))
        .route("/info", get(api_info))
        // Query endpoints
        .route("/query", post(query_handler))
        .route("/query/stream", post(stream_handler))
        // Batch processing endpoints
        .route("/batch", post(batch_handler))
        .route("/batch/:job_id", get(batch_status_handler))
        // Conversation endpoints
        .route("/conversation", post(conversation_handler))
        .route(
            "/conversation/:conversation_id",
            get(conversation_history_handler),
        )
        // Cache management endpoints
        .route("/cache/stats", get(cache_stats_handler))
        .route("/cache/clear", post(cache_clear_handler))
        // Tool execution endpoints
        .route("/tools", get(list_tools_handler))
        .route("/tools/:tool_id/execute", post(execute_tool_handler))
        .with_state(state)
}
/// Health check endpoint
|
||
|
|
async fn health_check(State(_state): State<ApiState>) -> impl IntoResponse {
|
||
|
|
Json(HealthResponse {
|
||
|
|
status: "healthy".to_string(),
|
||
|
|
version: env!("CARGO_PKG_VERSION").to_string(),
|
||
|
|
components: ComponentHealth {
|
||
|
|
agent: "operational".to_string(),
|
||
|
|
cache: "operational".to_string(),
|
||
|
|
database: "operational".to_string(),
|
||
|
|
},
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
/// API information endpoint
|
||
|
|
async fn api_info() -> impl IntoResponse {
|
||
|
|
Json(serde_json::json!({
|
||
|
|
"name": "Provisioning RAG API",
|
||
|
|
"version": env!("CARGO_PKG_VERSION"),
|
||
|
|
"description": "REST API for Retrieval-Augmented Generation system",
|
||
|
|
"endpoints": {
|
||
|
|
"query": "POST /query",
|
||
|
|
"streaming": "POST /query/stream",
|
||
|
|
"batch": "POST /batch",
|
||
|
|
"conversation": "POST /conversation",
|
||
|
|
"cache": "GET /cache/stats",
|
||
|
|
"health": "GET /health"
|
||
|
|
}
|
||
|
|
}))
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Query handler for single questions
|
||
|
|
async fn query_handler(
|
||
|
|
State(state): State<ApiState>,
|
||
|
|
Json(request): Json<QueryRequest>,
|
||
|
|
) -> Response {
|
||
|
|
match handle_query(&state, request).await {
|
||
|
|
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||
|
|
Err(e) => error_response(e).into_response(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Stream handler for real-time responses
///
/// Stub: always answers `501 Not Implemented`. The request body is parsed
/// (so malformed JSON still yields a 4xx from the extractor) but unused.
async fn stream_handler(
    State(_state): State<ApiState>,
    Json(_request): Json<QueryRequest>,
) -> impl IntoResponse {
    // Streaming implementation for Phase 8
    StatusCode::NOT_IMPLEMENTED
}
/// Batch processing handler
|
||
|
|
async fn batch_handler(
|
||
|
|
State(state): State<ApiState>,
|
||
|
|
Json(request): Json<BatchQueryRequest>,
|
||
|
|
) -> Response {
|
||
|
|
match handle_batch(&state, request).await {
|
||
|
|
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||
|
|
Err(e) => error_response(e).into_response(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Batch status handler
///
/// Stub: always answers `501 Not Implemented`; the `job_id` path segment is
/// captured but ignored.
async fn batch_status_handler(
    State(_state): State<ApiState>,
    Path(_job_id): Path<String>,
) -> impl IntoResponse {
    // Batch status tracking for Phase 8
    StatusCode::NOT_IMPLEMENTED
}
/// Conversation handler for multi-turn Q&A
|
||
|
|
async fn conversation_handler(
|
||
|
|
State(state): State<ApiState>,
|
||
|
|
Json(request): Json<ConversationRequest>,
|
||
|
|
) -> Response {
|
||
|
|
match handle_conversation(&state, request).await {
|
||
|
|
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||
|
|
Err(e) => error_response(e).into_response(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Conversation history handler
///
/// Stub: always answers `501 Not Implemented`; the `conversation_id` path
/// segment is captured but ignored.
async fn conversation_history_handler(
    State(_state): State<ApiState>,
    Path(_conversation_id): Path<String>,
) -> impl IntoResponse {
    // Conversation history retrieval for Phase 8
    StatusCode::NOT_IMPLEMENTED
}
/// Cache statistics handler
///
/// NOTE(review): returns all-zero placeholder stats and never consults
/// `state.cache` — wire this to `ResponseCache` when its stats API is ready.
async fn cache_stats_handler(State(_state): State<ApiState>) -> impl IntoResponse {
    // Cache stats retrieval
    (
        StatusCode::OK,
        Json(CacheStatsResponse {
            items_in_cache: 0,
            hits: 0,
            misses: 0,
            hit_rate: 0.0,
        }),
    )
}
/// Cache clear handler
///
/// NOTE(review): answers `200 OK` without touching `state.cache`, so the
/// cache is not actually cleared yet — confirm this is intentional.
async fn cache_clear_handler(State(_state): State<ApiState>) -> impl IntoResponse {
    // Clear cache
    StatusCode::OK
}
/// List available tools handler
|
||
|
|
async fn list_tools_handler(State(_state): State<ApiState>) -> impl IntoResponse {
|
||
|
|
Json(serde_json::json!({
|
||
|
|
"tools": []
|
||
|
|
}))
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Execute tool handler
///
/// Stub: always answers `501 Not Implemented`; the `tool_id` path segment
/// and JSON input payload are parsed but ignored.
async fn execute_tool_handler(
    State(_state): State<ApiState>,
    Path(_tool_id): Path<String>,
    Json(_input): Json<serde_json::Value>,
) -> impl IntoResponse {
    // Tool execution for Phase 8
    StatusCode::NOT_IMPLEMENTED
}
/// Handle query request
|
||
|
|
async fn handle_query(state: &ApiState, request: QueryRequest) -> Result<QueryResponse> {
|
||
|
|
// Optimize query if context provided
|
||
|
|
let optimized_query = if let Some(context) = request.conversation_context {
|
||
|
|
state
|
||
|
|
.optimizer
|
||
|
|
.optimize_with_context(&request.query, Some(&context))?
|
||
|
|
} else {
|
||
|
|
state.optimizer.optimize(&request.query)?
|
||
|
|
};
|
||
|
|
|
||
|
|
// Generate response
|
||
|
|
let response = AgentResponse {
|
||
|
|
answer: format!("Answer for: {}", optimized_query.optimized),
|
||
|
|
sources: vec![],
|
||
|
|
confidence: 0.9,
|
||
|
|
context: optimized_query.optimized,
|
||
|
|
};
|
||
|
|
|
||
|
|
Ok(QueryResponse {
|
||
|
|
answer: response.answer,
|
||
|
|
sources: response.sources,
|
||
|
|
confidence: response.confidence,
|
||
|
|
context: response.context,
|
||
|
|
from_cache: None,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Handle batch query request
|
||
|
|
async fn handle_batch(_state: &ApiState, request: BatchQueryRequest) -> Result<BatchResponse> {
|
||
|
|
let queries: Vec<BatchQuery> = request.queries.into_iter().map(BatchQuery::new).collect();
|
||
|
|
|
||
|
|
let job = BatchJob::new(queries)
|
||
|
|
.with_max_concurrent(request.max_concurrent)
|
||
|
|
.with_timeout(request.timeout_secs);
|
||
|
|
|
||
|
|
Ok(BatchResponse {
|
||
|
|
job_id: job.job_id.clone(),
|
||
|
|
results: vec![],
|
||
|
|
stats: BatchStats {
|
||
|
|
total_queries: job.queries.len(),
|
||
|
|
successful_queries: 0,
|
||
|
|
failed_queries: 0,
|
||
|
|
success_rate: 0.0,
|
||
|
|
total_duration_ms: 0,
|
||
|
|
},
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Handle conversation request
|
||
|
|
async fn handle_conversation(
|
||
|
|
_state: &ApiState,
|
||
|
|
request: ConversationRequest,
|
||
|
|
) -> Result<ConversationResponse> {
|
||
|
|
Ok(ConversationResponse {
|
||
|
|
response: format!("Response to: {}", request.message),
|
||
|
|
context: request.message,
|
||
|
|
followup_suggestions: None,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Convert RagError to HTTP response
|
||
|
|
fn error_response(error: RagError) -> impl IntoResponse {
|
||
|
|
let (status, code) = match &error {
|
||
|
|
RagError::Config(_) => (StatusCode::BAD_REQUEST, "INVALID_CONFIG"),
|
||
|
|
RagError::Embedding(_) => (StatusCode::INTERNAL_SERVER_ERROR, "EMBEDDING_ERROR"),
|
||
|
|
RagError::Retrieval(_) => (StatusCode::INTERNAL_SERVER_ERROR, "RETRIEVAL_ERROR"),
|
||
|
|
RagError::LlmError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "LLM_ERROR"),
|
||
|
|
RagError::Database(_) => (StatusCode::INTERNAL_SERVER_ERROR, "DB_ERROR"),
|
||
|
|
RagError::InvalidInput(_) => (StatusCode::BAD_REQUEST, "INVALID_INPUT"),
|
||
|
|
RagError::Internal(_) => (StatusCode::INTERNAL_SERVER_ERROR, "INTERNAL_ERROR"),
|
||
|
|
RagError::ToolError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "TOOL_ERROR"),
|
||
|
|
RagError::Context(_) => (StatusCode::BAD_REQUEST, "CONTEXT_ERROR"),
|
||
|
|
RagError::Io(_) => (StatusCode::INTERNAL_SERVER_ERROR, "IO_ERROR"),
|
||
|
|
RagError::Surrealdb(_) => (StatusCode::INTERNAL_SERVER_ERROR, "DB_ERROR"),
|
||
|
|
RagError::Json(_) => (StatusCode::BAD_REQUEST, "JSON_ERROR"),
|
||
|
|
RagError::Regex(_) => (StatusCode::BAD_REQUEST, "REGEX_ERROR"),
|
||
|
|
RagError::Http(_) => (StatusCode::INTERNAL_SERVER_ERROR, "HTTP_ERROR"),
|
||
|
|
RagError::NotFound(_) => (StatusCode::NOT_FOUND, "NOT_FOUND"),
|
||
|
|
};
|
||
|
|
|
||
|
|
let response = ErrorResponse {
|
||
|
|
error: error.to_string(),
|
||
|
|
code: code.to_string(),
|
||
|
|
status: status.as_u16(),
|
||
|
|
};
|
||
|
|
|
||
|
|
(status, Json(response)).into_response()
|
||
|
|
}
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;

    // QueryRequest round-trips through serde and carries the query text.
    #[test]
    fn test_query_request_serialization() {
        let request = QueryRequest {
            query: "What is Kubernetes?".to_string(),
            conversation_context: None,
            use_hybrid_search: true,
            num_results: 5,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("Kubernetes"));
    }

    // Config errors map to 400 / INVALID_CONFIG, mirroring error_response.
    #[test]
    fn test_error_response_status_codes() {
        let error = RagError::Config("test".into());
        let (status, code) = match &error {
            RagError::Config(_) => (StatusCode::BAD_REQUEST, "INVALID_CONFIG"),
            _ => (StatusCode::INTERNAL_SERVER_ERROR, "ERROR"),
        };
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(code, "INVALID_CONFIG");
    }

    // BatchStats stores the values it is given (no derived computation).
    #[test]
    fn test_batch_stats_calculation() {
        let stats = BatchStats {
            total_queries: 10,
            successful_queries: 9,
            failed_queries: 1,
            success_rate: 0.9,
            total_duration_ms: 5000,
        };

        assert_eq!(stats.total_queries, 10);
        assert_eq!(stats.success_rate, 0.9);
    }

    // Serde default helpers produce the documented values (5 / 30s).
    #[test]
    fn test_default_batch_parameters() {
        let request = BatchQueryRequest {
            queries: vec!["test".to_string()],
            max_concurrent: default_max_concurrent(),
            timeout_secs: default_timeout_secs(),
        };

        assert_eq!(request.max_concurrent, 5);
        assert_eq!(request.timeout_secs, 30);
    }

    // HealthResponse nests ComponentHealth and exposes its fields.
    #[test]
    fn test_health_response_structure() {
        let health = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
            components: ComponentHealth {
                agent: "operational".to_string(),
                cache: "operational".to_string(),
                database: "operational".to_string(),
            },
        };

        assert_eq!(health.status, "healthy");
        assert_eq!(health.components.agent, "operational");
    }
}