chore: add typedialog-ai AI with RAG and knowledge graph
This commit is contained in:
parent
34508cddf4
commit
01980c9b8d
62
crates/typedialog-ai/Cargo.toml
Normal file
62
crates/typedialog-ai/Cargo.toml
Normal file
@ -0,0 +1,62 @@
|
||||
[package]
|
||||
name = "typedialog-ai"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
repository.workspace = true
|
||||
license.workspace = true
|
||||
description = "AI-powered configuration assistant backend and microservice for TypeDialog"
|
||||
|
||||
[dependencies]
|
||||
# Internal
|
||||
typedialog-core = { path = "../typedialog-core", features = ["ai_backend"] }
|
||||
|
||||
# Workspace dependencies (shared with other crates)
|
||||
tokio = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
surrealdb = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serde_yaml = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
dialoguer = { workspace = true }
|
||||
colored = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
|
||||
# Web and HTTP dependencies (now aligned with workspace versions)
|
||||
# Code migrated to support workspace versions:
|
||||
# - axum: Upgraded from 0.7 to 0.8.8 (WebSocket Message::Text now uses Utf8Bytes)
|
||||
# - reqwest: Using workspace 0.12 (streaming API compatible)
|
||||
# - tower/tower-http: Aligned with axum 0.8.8
|
||||
reqwest = { workspace = true, features = ["json", "stream"] }
|
||||
axum = { workspace = true, features = ["ws"] }
|
||||
tower = { workspace = true }
|
||||
tower-http = { workspace = true, features = ["cors", "trace"] }
|
||||
|
||||
[features]
|
||||
default = ["openai"]
|
||||
openai = []
|
||||
anthropic = []
|
||||
ollama = []
|
||||
all-providers = ["openai", "anthropic", "ollama"]
|
||||
|
||||
[lib]
|
||||
name = "typedialog_ai"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "typedialog-ai"
|
||||
path = "src/main.rs"
|
||||
|
||||
[package.metadata.binstall]
|
||||
pkg-url = "{ repo }/releases/download/v{ version }/typedialog-{ target }.tar.gz"
|
||||
bin-dir = "bin/{ bin }"
|
||||
pkg-fmt = "tgz"
|
||||
131
crates/typedialog-ai/src/api/error.rs
Normal file
131
crates/typedialog-ai/src/api/error.rs
Normal file
@ -0,0 +1,131 @@
|
||||
//! API error handling
|
||||
|
||||
use super::types::ErrorResponse;
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
|
||||
/// API error type returned by all HTTP handlers in this service.
///
/// Each variant carries a human-readable message fragment; the variant itself
/// determines the HTTP status code and the machine-readable error code sent to
/// clients (see the `code`/`status`/`message` accessors on the impl).
#[derive(Debug)]
pub enum ApiError {
    /// Conversation not found (payload: the conversation id that was looked up)
    ConversationNotFound(String),

    /// Invalid request (payload: explanation returned verbatim to the client)
    InvalidRequest(String),

    /// Database operation failed
    DatabaseError(String),

    /// LLM operation failed
    LlmError(String),

    /// Schema analysis failed
    SchemaError(String),

    /// Configuration generation failed
    GenerationError(String),

    /// Internal server error (also the catch-all used by the `From` impls)
    InternalError(String),
}
||||
|
||||
impl ApiError {
    /// Get error code for client handling.
    ///
    /// These strings are part of the wire contract (clients branch on them),
    /// so they must stay stable across releases.
    fn code(&self) -> &'static str {
        match self {
            ApiError::ConversationNotFound(_) => "NOT_FOUND",
            ApiError::InvalidRequest(_) => "INVALID_REQUEST",
            ApiError::DatabaseError(_) => "DATABASE_ERROR",
            ApiError::LlmError(_) => "LLM_ERROR",
            ApiError::SchemaError(_) => "SCHEMA_ERROR",
            ApiError::GenerationError(_) => "GENERATION_ERROR",
            ApiError::InternalError(_) => "INTERNAL_ERROR",
        }
    }

    /// Get HTTP status code for this error.
    ///
    /// Client-caused errors (bad input, bad schema) map to 4xx; everything
    /// that failed on the server side maps to 500.
    fn status(&self) -> StatusCode {
        match self {
            ApiError::ConversationNotFound(_) => StatusCode::NOT_FOUND,
            ApiError::InvalidRequest(_) => StatusCode::BAD_REQUEST,
            ApiError::DatabaseError(_) => StatusCode::INTERNAL_SERVER_ERROR,
            ApiError::LlmError(_) => StatusCode::INTERNAL_SERVER_ERROR,
            ApiError::SchemaError(_) => StatusCode::BAD_REQUEST,
            ApiError::GenerationError(_) => StatusCode::INTERNAL_SERVER_ERROR,
            ApiError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }

    /// Get the human-readable error message.
    ///
    /// `InvalidRequest` is returned verbatim; every other variant is wrapped
    /// with a short category prefix so logs/clients can tell them apart.
    fn message(&self) -> String {
        match self {
            ApiError::ConversationNotFound(id) => format!("Conversation '{}' not found", id),
            ApiError::InvalidRequest(msg) => msg.clone(),
            ApiError::DatabaseError(msg) => format!("Database error: {}", msg),
            ApiError::LlmError(msg) => format!("LLM error: {}", msg),
            ApiError::SchemaError(msg) => format!("Schema error: {}", msg),
            ApiError::GenerationError(msg) => format!("Generation error: {}", msg),
            ApiError::InternalError(msg) => format!("Internal error: {}", msg),
        }
    }
}
||||
|
||||
impl IntoResponse for ApiError {
|
||||
fn into_response(self) -> Response {
|
||||
let status = self.status();
|
||||
let error_response = ErrorResponse::new(self.message(), self.code());
|
||||
|
||||
(status, Json(error_response)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
// Blanket conversion for `?` on anyhow results in handlers: any unclassified
// failure is surfaced as a 500 InternalError carrying the error's Display text.
impl From<anyhow::Error> for ApiError {
    fn from(err: anyhow::Error) -> Self {
        ApiError::InternalError(err.to_string())
    }
}
||||
|
||||
// Convenience conversion from a bare message string; treated as an internal
// (500) error since no more specific category is known.
impl From<String> for ApiError {
    fn from(err: String) -> Self {
        ApiError::InternalError(err)
    }
}
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Each test pins one piece of the wire contract (code / status / message)
    // so accidental renames of the client-visible strings fail CI.

    #[test]
    fn test_conversation_not_found_code() {
        let err = ApiError::ConversationNotFound("123".to_string());
        assert_eq!(err.code(), "NOT_FOUND");
    }

    #[test]
    fn test_conversation_not_found_status() {
        let err = ApiError::ConversationNotFound("123".to_string());
        assert_eq!(err.status(), StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_invalid_request_code() {
        let err = ApiError::InvalidRequest("bad data".to_string());
        assert_eq!(err.code(), "INVALID_REQUEST");
    }

    #[test]
    fn test_llm_error_code() {
        let err = ApiError::LlmError("api failed".to_string());
        assert_eq!(err.code(), "LLM_ERROR");
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_error_message() {
        // The message must embed the offending conversation id for debuggability.
        let err = ApiError::ConversationNotFound("conv-123".to_string());
        assert!(err.message().contains("conv-123"));
    }
}
||||
15
crates/typedialog-ai/src/api/mod.rs
Normal file
15
crates/typedialog-ai/src/api/mod.rs
Normal file
@ -0,0 +1,15 @@
|
||||
//! REST API module
//!
//! Provides HTTP and WebSocket endpoints for the AI configuration assistant service.
//!
//! Submodules:
//! - [`error`]: the `ApiError` type and its HTTP/JSON mapping
//! - [`rest`]: router construction and HTTP handlers
//! - [`types`]: request/response DTOs
//! - [`websocket`]: token-streaming WebSocket handler

pub mod error;
pub mod rest;
pub mod types;
pub mod websocket;

// Re-exports form the crate-facing API surface; `allow(unused_imports)`
// keeps the build clean while some consumers are still unwritten.
#[allow(unused_imports)]
pub use error::ApiError;
#[allow(unused_imports)]
pub use rest::{create_router, AppState};
#[allow(unused_imports)]
pub use websocket::{WsMessage, WsResponse};
||||
311
crates/typedialog-ai/src/api/rest.rs
Normal file
311
crates/typedialog-ai/src/api/rest.rs
Normal file
@ -0,0 +1,311 @@
|
||||
//! REST API handlers for TypeDialog AI Service
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
routing::{delete, get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::{error::ApiError, types::*, websocket};
|
||||
use crate::assistant::ConfigAssistant;
|
||||
use crate::storage::SurrealDbClient;
|
||||
|
||||
/// Shared application state, cloned cheaply per request (all fields are
/// `Arc`s or `Copy`-like) and injected into handlers via axum's `State`.
#[derive(Clone)]
pub struct AppState {
    /// Database client (SurrealDB), shared across all handlers
    pub db: Arc<SurrealDbClient>,

    /// Currently active assistants keyed by conversation ID.
    /// NOTE(review): guarded by a std `Mutex` — callers must not hold the
    /// guard across an `.await`.
    pub assistants: Arc<Mutex<std::collections::HashMap<String, ConfigAssistant>>>,

    /// Server start time for uptime tracking (reported by /health)
    pub start_time: std::time::Instant,
}
||||
|
||||
/// Create the Axum router with all routes wired to `state`.
///
/// Route groups: static web UI, health check, conversation CRUD, message
/// handling, config generation, suggestions, and the WebSocket endpoint.
/// Path parameters use axum 0.8 `{id}` syntax.
pub fn create_router(state: AppState) -> Router {
    use crate::web_ui;

    Router::new()
        // Web UI (static assets served from handlers)
        .route("/", get(web_ui::index))
        .route("/index.html", get(web_ui::index))
        .route("/styles.css", get(web_ui::styles))
        .route("/app.js", get(web_ui::app))
        // Health check
        .route("/health", get(health_check))
        // Conversation management
        .route("/conversations", post(create_conversation))
        .route("/conversations/{id}", get(get_conversation))
        .route("/conversations/{id}", delete(delete_conversation))
        // Message handling
        .route("/conversations/{id}/messages", post(send_message))
        // Configuration generation
        .route("/conversations/{id}/generate", post(generate_config))
        // Suggestions
        .route(
            "/conversations/{id}/suggestions/{field}",
            get(get_suggestions),
        )
        // WebSocket streaming
        .route("/ws/{id}", axum::routing::get(websocket::handle_websocket))
        .with_state(state)
}
||||
|
||||
// ============================================================================
|
||||
// HEALTH CHECK
|
||||
// ============================================================================
|
||||
|
||||
/// Health check endpoint
|
||||
pub async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
|
||||
let uptime = state.start_time.elapsed().as_secs();
|
||||
|
||||
Json(HealthResponse {
|
||||
status: "healthy".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
uptime,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// CONVERSATION MANAGEMENT
|
||||
// ============================================================================
|
||||
|
||||
/// Create a new conversation.
///
/// # Request
/// ```json
/// {"schema_id": "my-schema"}
/// ```
///
/// # Response
/// ```json
/// {
///   "conversation_id": "uuid",
///   "schema_id": "my-schema",
///   "message": "Conversation started"
/// }
/// ```
///
/// # Errors
/// Returns `ApiError::DatabaseError` (500) if the conversation row cannot
/// be created.
pub async fn create_conversation(
    State(state): State<AppState>,
    Json(req): Json<CreateConversationRequest>,
) -> Result<(StatusCode, Json<CreateConversationResponse>), ApiError> {
    tracing::info!(schema_id = %req.schema_id, "Creating new conversation");

    // Create conversation in database; the db client assigns the id.
    let conv_id = state
        .db
        .create_conversation(&req.schema_id)
        .await
        .map_err(|e| ApiError::DatabaseError(e.to_string()))?;

    // 201 Created with the new id echoed back to the client.
    Ok((
        StatusCode::CREATED,
        Json(CreateConversationResponse {
            conversation_id: conv_id,
            schema_id: req.schema_id,
            message: "Conversation created successfully".to_string(),
        }),
    ))
}
||||
|
||||
/// Get conversation information.
///
/// NOTE(review): currently a stub — it returns placeholder metadata
/// (schema_id "schema", status "active", timestamps = now) without
/// consulting the database or verifying the id exists. The TODO below must
/// land before this endpoint's output can be trusted.
pub async fn get_conversation(
    State(_state): State<AppState>,
    Path(id): Path<String>,
) -> Result<Json<ConversationInfo>, ApiError> {
    tracing::debug!(conversation_id = %id, "Getting conversation info");

    // TODO: Load conversation from database
    Ok(Json(ConversationInfo {
        id: id.clone(),
        schema_id: "schema".to_string(),
        status: "active".to_string(),
        created_at: Utc::now().to_rfc3339(),
        updated_at: Utc::now().to_rfc3339(),
    }))
}
||||
|
||||
/// Delete a conversation.
///
/// NOTE(review): stub — always returns 204 without touching the database,
/// so deletion is not yet persisted (see TODO).
pub async fn delete_conversation(
    State(_state): State<AppState>,
    Path(id): Path<String>,
) -> Result<StatusCode, ApiError> {
    tracing::info!(conversation_id = %id, "Deleting conversation");

    // TODO: Delete conversation from database
    Ok(StatusCode::NO_CONTENT)
}
||||
|
||||
// ============================================================================
|
||||
// MESSAGE HANDLING
|
||||
// ============================================================================
|
||||
|
||||
/// Send a message to the assistant.
///
/// # Request
/// ```json
/// {"message": "What port should I use?"}
/// ```
///
/// # Response
/// ```json
/// {
///   "text": "For a web server, port 8080 is commonly used...",
///   "suggestions": [
///     {
///       "field": "port",
///       "value": "8080",
///       "reasoning": "Common web server port",
///       "confidence": 0.92
///     }
///   ],
///   "message_id": "msg-uuid",
///   "rag_context": "..."
/// }
/// ```
///
/// # Errors
/// - `InvalidRequest` (400) when the message is empty/whitespace-only.
/// - `InternalError` (500) always, for now: the assistant wiring is a TODO,
///   so every non-empty message currently fails.
pub async fn send_message(
    State(_state): State<AppState>,
    Path(id): Path<String>,
    Json(req): Json<SendMessageRequest>,
) -> Result<Json<SendMessageResponse>, ApiError> {
    tracing::debug!(conversation_id = %id, message_len = req.message.len(), "Processing message");

    // Reject empty input before doing any work.
    if req.message.trim().is_empty() {
        return Err(ApiError::InvalidRequest(
            "Message cannot be empty".to_string(),
        ));
    }

    // TODO: Load assistant from state, call send_message(), return response
    Err(ApiError::InternalError(
        "Assistant not initialized".to_string(),
    ))
}
||||
|
||||
// ============================================================================
|
||||
// CONFIGURATION GENERATION
|
||||
// ============================================================================
|
||||
|
||||
/// Generate configuration from conversation.
///
/// # Request
/// ```json
/// {"format": "json"}
/// ```
///
/// # Response
/// ```json
/// {
///   "content": "{\"port\": 8080, ...}",
///   "format": "json",
///   "is_valid": true,
///   "errors": []
/// }
/// ```
///
/// # Errors
/// - `InvalidRequest` (400) for any format other than "json"/"yaml"/"toml".
/// - `InternalError` (500) always, for now: generation is a TODO stub.
pub async fn generate_config(
    State(_state): State<AppState>,
    Path(id): Path<String>,
    Json(req): Json<GenerateConfigRequest>,
) -> Result<Json<GenerateConfigResponse>, ApiError> {
    tracing::info!(conversation_id = %id, format = %req.format, "Generating configuration");

    // Validate format up front so clients get a 400 rather than a late failure.
    if !matches!(req.format.as_str(), "json" | "yaml" | "toml") {
        return Err(ApiError::InvalidRequest(format!(
            "Invalid format '{}'. Must be 'json', 'yaml', or 'toml'",
            req.format
        )));
    }

    // TODO: Load assistant from state, call generate_config(), return response
    Err(ApiError::InternalError(
        "Assistant not initialized".to_string(),
    ))
}
||||
|
||||
// ============================================================================
|
||||
// SUGGESTIONS
|
||||
// ============================================================================
|
||||
|
||||
/// Get field suggestions.
///
/// # Response
/// ```json
/// {
///   "field": "port",
///   "suggestions": [
///     {
///       "field": "port",
///       "value": "8080",
///       "reasoning": "Found in examples",
///       "confidence": 0.85
///     }
///   ]
/// }
/// ```
///
/// # Errors
/// `InternalError` (500) always, for now: suggestion lookup is a TODO stub.
pub async fn get_suggestions(
    State(_state): State<AppState>,
    Path((id, field)): Path<(String, String)>,
) -> Result<Json<SuggestionsResponse>, ApiError> {
    tracing::debug!(conversation_id = %id, field = %field, "Getting suggestions");

    // TODO: Load assistant from state, call suggest_values(), return response
    Err(ApiError::InternalError(
        "Assistant not initialized".to_string(),
    ))
}
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Builds an AppState backed by an in-memory SurrealDB instance so tests
    // need no external services.
    async fn create_test_state() -> Result<AppState, Box<dyn std::error::Error>> {
        let db = Arc::new(SurrealDbClient::new("memory://", "default", "test").await?);
        Ok(AppState {
            db,
            assistants: Arc::new(Mutex::new(std::collections::HashMap::new())),
            start_time: std::time::Instant::now(),
        })
    }

    #[tokio::test]
    async fn test_router_creation() {
        // Smoke test: constructing the router must not panic (catches route
        // syntax errors at test time).
        let state = create_test_state().await.unwrap();
        let _router = create_router(state);
        // Router created successfully
    }

    #[test]
    fn test_create_conversation_request_validation() {
        let req = CreateConversationRequest {
            schema_id: "test-schema".to_string(),
        };
        assert!(!req.schema_id.is_empty());
    }

    #[test]
    fn test_generate_config_request_format_validation() {
        // Mirrors the format whitelist enforced by generate_config().
        let valid_formats = vec!["json", "yaml", "toml"];
        for format in valid_formats {
            let req = GenerateConfigRequest {
                format: format.to_string(),
            };
            assert!(matches!(req.format.as_str(), "json" | "yaml" | "toml"));
        }
    }

    #[test]
    fn test_send_message_request_serialization() {
        // Round-trip through serde_json to pin the JSON field names.
        let req = SendMessageRequest {
            message: "Test message".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let deserialized: SendMessageRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.message, "Test message");
    }
}
||||
220
crates/typedialog-ai/src/api/types.rs
Normal file
220
crates/typedialog-ai/src/api/types.rs
Normal file
@ -0,0 +1,220 @@
|
||||
//! API request and response types
|
||||
|
||||
use crate::assistant::{AssistantResponse, FieldSuggestion, GeneratedConfig};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Request to start a new conversation (`POST /conversations` body).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateConversationRequest {
    /// Schema identifier or name the conversation will configure
    pub schema_id: String,
}
||||
|
||||
/// Response when a conversation is created (201 body).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateConversationResponse {
    /// Unique conversation identifier assigned by the server
    pub conversation_id: String,

    /// Schema identifier echoed back from the request
    pub schema_id: String,

    /// Human-readable status message
    pub message: String,
}
||||
|
||||
/// Request to send a message to the assistant
/// (`POST /conversations/{id}/messages` body).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SendMessageRequest {
    /// User message text; must be non-empty (enforced by the handler)
    pub message: String,
}
||||
|
||||
/// Response from sending a message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SendMessageResponse {
    /// Assistant's text response
    pub text: String,

    /// Field suggestions extracted from the response
    pub suggestions: Vec<FieldSuggestion>,

    /// Message ID for tracking
    pub message_id: String,

    /// RAG context used for generation, if any was retrieved
    pub rag_context: Option<String>,
}
||||
|
||||
impl From<AssistantResponse> for SendMessageResponse {
|
||||
fn from(resp: AssistantResponse) -> Self {
|
||||
SendMessageResponse {
|
||||
text: resp.text,
|
||||
suggestions: resp.suggestions,
|
||||
message_id: resp.message_id,
|
||||
rag_context: resp.rag_context,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Request to get field suggestions.
/// NOTE(review): the REST route passes the field as a path parameter, so this
/// body type appears currently unused — confirm before removing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SuggestionsRequest {
    /// Field name to get suggestions for
    pub field: String,
}
||||
|
||||
/// Response with field suggestions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SuggestionsResponse {
    /// Field name the suggestions apply to
    pub field: String,

    /// List of suggestions, best first (ordering presumed — TODO confirm)
    pub suggestions: Vec<FieldSuggestion>,
}
||||
|
||||
/// Request to generate the final configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateConfigRequest {
    /// Output format: "json", "yaml", or "toml" (validated by the handler)
    pub format: String,
}
||||
|
||||
/// Response with the generated configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateConfigResponse {
    /// Generated configuration content, serialized in `format`
    pub content: String,

    /// Configuration format ("json", "yaml", or "toml")
    pub format: String,

    /// Whether the configuration passed schema validation
    pub is_valid: bool,

    /// Validation errors; empty when `is_valid` is true
    pub errors: Vec<String>,
}
||||
|
||||
impl From<GeneratedConfig> for GenerateConfigResponse {
|
||||
fn from(config: GeneratedConfig) -> Self {
|
||||
GenerateConfigResponse {
|
||||
content: config.content,
|
||||
format: config.format,
|
||||
is_valid: config.is_valid,
|
||||
errors: config.errors,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Conversation metadata returned by `GET /conversations/{id}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationInfo {
    /// Conversation ID
    pub id: String,

    /// Schema ID this conversation configures
    pub schema_id: String,

    /// Conversation status (e.g. "active")
    pub status: String,

    /// Created timestamp, RFC 3339
    pub created_at: String,

    /// Updated timestamp, RFC 3339
    pub updated_at: String,
}
||||
|
||||
/// Health check response (`GET /health`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthResponse {
    /// Service status (currently always "healthy")
    pub status: String,

    /// Service version (crate version at compile time)
    pub version: String,

    /// Uptime in seconds since server start
    pub uptime: u64,
}
||||
|
||||
/// Error response body sent for every `ApiError`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
    /// Human-readable error message
    pub error: String,

    /// Stable error code for clients to branch on (e.g. "NOT_FOUND")
    pub code: String,

    /// Optional additional details; omitted (`null`) by default
    pub details: Option<String>,
}
||||
|
||||
impl ErrorResponse {
|
||||
/// Create a new error response
|
||||
pub fn new(error: impl Into<String>, code: impl Into<String>) -> Self {
|
||||
ErrorResponse {
|
||||
error: error.into(),
|
||||
code: code.into(),
|
||||
details: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add details to error response
|
||||
pub fn with_details(mut self, details: impl Into<String>) -> Self {
|
||||
self.details = Some(details.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip serialization tests pin the JSON field names of the DTOs
    // (the wire contract with the web UI and API clients).

    #[test]
    fn test_create_conversation_request_serialization() {
        let req = CreateConversationRequest {
            schema_id: "test-schema".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let deserialized: CreateConversationRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.schema_id, "test-schema");
    }

    #[test]
    fn test_send_message_request_serialization() {
        let req = SendMessageRequest {
            message: "What port should I use?".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let deserialized: SendMessageRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.message, "What port should I use?");
    }

    #[test]
    fn test_error_response_creation() {
        let error = ErrorResponse::new("Invalid request", "INVALID_REQUEST");
        assert_eq!(error.error, "Invalid request");
        assert_eq!(error.code, "INVALID_REQUEST");
    }

    #[test]
    fn test_error_response_with_details() {
        // Builder helper must populate the optional details field.
        let error = ErrorResponse::new("Invalid request", "INVALID_REQUEST")
            .with_details("Field 'message' is required");
        assert!(error.details.is_some());
    }

    #[test]
    fn test_generate_config_request_serialization() {
        let req = GenerateConfigRequest {
            format: "json".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        let deserialized: GenerateConfigRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.format, "json");
    }
}
||||
333
crates/typedialog-ai/src/api/websocket.rs
Normal file
333
crates/typedialog-ai/src/api/websocket.rs
Normal file
@ -0,0 +1,333 @@
|
||||
//! WebSocket streaming for real-time LLM responses
|
||||
//!
|
||||
//! Provides token-by-token streaming of assistant responses via WebSocket.
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
Path, State,
|
||||
},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
use super::rest::AppState;
|
||||
use crate::assistant::ConfigAssistant;
|
||||
|
||||
/// WebSocket message from the client.
/// `r#type` is used because `type` is a Rust keyword; serde still emits
/// the JSON key "type".
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsMessage {
    /// Message type: "message", "generate", or "suggestions"
    /// (anything else gets an error reply)
    pub r#type: String,

    /// Message content (user message text, or the field name for suggestions)
    pub content: String,

    /// Additional data (e.g. format for config generation)
    pub data: Option<serde_json::Value>,
}
||||
|
||||
/// WebSocket response sent to the client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsResponse {
    /// Response type: "start", "chunk", "end", "error", "suggestions", etc.
    pub r#type: String,

    /// Response content (a token chunk, an error message, or serialized data)
    pub content: String,

    /// Optional metadata (e.g. message_id, confidence)
    pub metadata: Option<serde_json::Value>,
}
||||
|
||||
impl WsResponse {
|
||||
/// Create a streaming chunk response
|
||||
pub fn chunk(content: impl Into<String>) -> Self {
|
||||
WsResponse {
|
||||
r#type: "chunk".to_string(),
|
||||
content: content.into(),
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a start response
|
||||
pub fn start() -> Self {
|
||||
WsResponse {
|
||||
r#type: "start".to_string(),
|
||||
content: String::new(),
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an end response with metadata
|
||||
pub fn end(metadata: Option<serde_json::Value>) -> Self {
|
||||
WsResponse {
|
||||
r#type: "end".to_string(),
|
||||
content: String::new(),
|
||||
metadata,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an error response
|
||||
pub fn error(error: impl Into<String>) -> Self {
|
||||
WsResponse {
|
||||
r#type: "error".to_string(),
|
||||
content: error.into(),
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize to JSON
|
||||
pub fn to_json(&self) -> String {
|
||||
serde_json::to_string(self).unwrap_or_else(|_| {
|
||||
serde_json::to_string(&WsResponse::error("Serialization failed")).unwrap()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle WebSocket upgrade and streaming.
///
/// Accepts a WebSocket connection and streams LLM responses in real-time.
/// The client should send JSON messages with:
/// - type: "message", "generate", or "suggestions"
/// - content: the user input
/// - data: optional additional data (e.g. format for generation)
///
/// The upgraded socket is driven by `handle_socket`, which owns the
/// receive loop for the lifetime of the connection.
pub async fn handle_websocket(
    ws: WebSocketUpgrade,
    Path(conversation_id): Path<String>,
    State(state): State<AppState>,
) -> impl IntoResponse {
    ws.on_upgrade(|socket| handle_socket(socket, conversation_id, state))
}
||||
|
||||
/// Per-connection receive loop: reads client frames until close or error,
/// parses each text frame as a `WsMessage`, and dispatches on its `type`.
/// Send failures inside the dispatch handlers end the exchange; parse
/// failures and unknown types get an error frame but keep the loop alive.
async fn handle_socket(mut socket: WebSocket, conversation_id: String, state: AppState) {
    // Get or create assistant for this conversation.
    // NOTE(review): stays None for now — the dispatch handlers below are
    // stubs that ignore it; populated once assistant wiring lands.
    let mut assistant_opt = None;

    while let Some(msg_result) = socket.recv().await {
        match msg_result {
            Ok(Message::Text(text)) => {
                // Parse incoming message as the WsMessage envelope
                match serde_json::from_str::<WsMessage>(&text) {
                    Ok(ws_msg) => {
                        // Handle message based on type
                        match ws_msg.r#type.as_str() {
                            "message" => {
                                handle_message_type(
                                    &mut socket,
                                    &mut assistant_opt,
                                    &conversation_id,
                                    &state,
                                    &ws_msg,
                                )
                                .await;
                            }
                            "generate" => {
                                handle_generate_type(
                                    &mut socket,
                                    &mut assistant_opt,
                                    &conversation_id,
                                    &state,
                                    &ws_msg,
                                )
                                .await;
                            }
                            "suggestions" => {
                                handle_suggestions_type(
                                    &mut socket,
                                    &mut assistant_opt,
                                    &conversation_id,
                                    &state,
                                    &ws_msg,
                                )
                                .await;
                            }
                            _ => {
                                // Unknown type: report but keep the connection open
                                let error_resp = WsResponse::error(format!(
                                    "Unknown message type: {}",
                                    ws_msg.r#type
                                ));
                                // drop(...) discards the send result: if the
                                // client is gone, the next recv() ends the loop
                                drop(
                                    socket
                                        .send(Message::Text(error_resp.to_json().into()))
                                        .await,
                                );
                            }
                        }
                    }
                    Err(e) => {
                        // Malformed JSON from the client: report, keep listening
                        let error_resp =
                            WsResponse::error(format!("Failed to parse message: {}", e));
                        drop(
                            socket
                                .send(Message::Text(error_resp.to_json().into()))
                                .await,
                        );
                    }
                }
            }
            Ok(Message::Close(_)) => {
                tracing::debug!(conv_id = %conversation_id, "WebSocket closed by client");
                break;
            }
            Ok(_) => {
                // Ignore other message types (binary, ping/pong)
            }
            Err(e) => {
                tracing::error!(conv_id = %conversation_id, error = %e, "WebSocket error");
                break;
            }
        }
    }
}
||||
|
||||
/// Handle a "message" frame: emits start → word-by-word chunks → end.
///
/// NOTE(review): placeholder implementation — it echoes the client's own
/// words back as chunks with an artificial 50 ms delay instead of invoking
/// the assistant (see the in-body comments). The unused `_`-prefixed
/// parameters are the hooks for the real implementation.
async fn handle_message_type(
    socket: &mut WebSocket,
    _assistant_opt: &mut Option<ConfigAssistant>,
    _conversation_id: &str,
    _state: &AppState,
    ws_msg: &WsMessage,
) {
    // Send start marker; bail out silently if the client is already gone
    let start_resp = WsResponse::start();
    if socket
        .send(Message::Text(start_resp.to_json().into()))
        .await
        .is_err()
    {
        return;
    }

    // Get or create assistant (for now, just send the message back with chunks)
    // In a full implementation, this would get the assistant from state
    // and call stream_message() on it

    // Simulate streaming response by splitting the content into words
    let words: Vec<&str> = ws_msg.content.split_whitespace().collect();

    for (i, word) in words.iter().enumerate() {
        // Re-insert the separating space on every word after the first
        let chunk_resp = WsResponse::chunk(if i == 0 {
            word.to_string()
        } else {
            format!(" {}", word)
        });

        if socket
            .send(Message::Text(chunk_resp.to_json().into()))
            .await
            .is_err()
        {
            return;
        }

        // Small delay to simulate token streaming
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
    }

    // Send end marker; send result intentionally discarded (connection may be gone)
    let end_resp = WsResponse::end(Some(json!({
        "type": "message",
        "status": "complete"
    })));
    drop(socket.send(Message::Text(end_resp.to_json().into())).await);
}
||||
|
||||
/// Handle a "generate" frame.
///
/// NOTE(review): not implemented — always replies with an error frame.
/// The unused parameters are the hooks for the eventual implementation.
async fn handle_generate_type(
    socket: &mut WebSocket,
    _assistant_opt: &mut Option<ConfigAssistant>,
    _conversation_id: &str,
    _state: &AppState,
    _ws_msg: &WsMessage,
) {
    // Send generation response (currently a fixed not-implemented error)
    let resp = WsResponse::error("Configuration generation not yet implemented via WebSocket");
    drop(socket.send(Message::Text(resp.to_json().into())).await);
}
||||
|
||||
/// Handle a "suggestions" frame: the frame's `content` names the field.
///
/// NOTE(review): returns hard-coded mock data ("example_value", 0.85) —
/// real suggestion lookup via the assistant is still to be wired in.
async fn handle_suggestions_type(
    socket: &mut WebSocket,
    _assistant_opt: &mut Option<ConfigAssistant>,
    _conversation_id: &str,
    _state: &AppState,
    ws_msg: &WsMessage,
) {
    // Get field name from content
    let field_name = &ws_msg.content;

    // Send suggestions response with mock data
    let suggestions = json!([
        {
            "field": field_name,
            "value": "example_value",
            "reasoning": "Found in examples",
            "confidence": 0.85
        }
    ]);

    // Suggestions array is carried as a JSON string in `content`;
    // the field name is duplicated into metadata for easy client routing
    let resp = WsResponse {
        r#type: "suggestions".to_string(),
        content: serde_json::to_string(&suggestions).unwrap_or_default(),
        metadata: Some(json!({"field": field_name})),
    };

    drop(socket.send(Message::Text(resp.to_json().into())).await);
}
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ws_message_serialization() {
        // Round-trip a WsMessage through serde_json and check the tag survives.
        let original = WsMessage {
            r#type: "message".to_string(),
            content: "What port?".to_string(),
            data: None,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let round_tripped: WsMessage = serde_json::from_str(&encoded).unwrap();
        assert_eq!(round_tripped.r#type, "message");
    }

    #[test]
    fn test_ws_response_chunk() {
        let response = WsResponse::chunk("port");
        assert_eq!(response.r#type, "chunk");
        assert_eq!(response.content, "port");
    }

    #[test]
    fn test_ws_response_start() {
        assert_eq!(WsResponse::start().r#type, "start");
    }

    #[test]
    fn test_ws_response_end() {
        assert_eq!(WsResponse::end(None).r#type, "end");
    }

    #[test]
    fn test_ws_response_error() {
        let response = WsResponse::error("Something failed");
        assert_eq!(response.r#type, "error");
        assert!(response.content.contains("failed"));
    }

    #[test]
    fn test_ws_response_to_json() {
        let serialized = WsResponse::chunk("test").to_json();
        assert!(serialized.contains("chunk"));
        assert!(serialized.contains("test"));
    }

    #[test]
    fn test_ws_response_with_metadata() {
        let meta = Some(serde_json::json!({"message_id": "123"}));
        assert!(WsResponse::end(meta).metadata.is_some());
    }
}
|
||||
664
crates/typedialog-ai/src/assistant/engine.rs
Normal file
664
crates/typedialog-ai/src/assistant/engine.rs
Normal file
@ -0,0 +1,664 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
//! AI Configuration Assistant Engine
|
||||
//!
|
||||
//! Core conversation loop integrating:
|
||||
//! - LLM provider (OpenAI, Claude, Ollama)
|
||||
//! - RAG system for example retrieval
|
||||
//! - Nickel schema understanding
|
||||
//! - Configuration generation and validation
|
||||
//! - Persistent conversation storage
|
||||
|
||||
use crate::llm::{
|
||||
prompts::system::config_assistant_system, rag_integration::format_rag_context,
|
||||
GenerationOptions, LlmProvider, Message as LlmMessage,
|
||||
};
|
||||
use crate::storage::{MessageRole, SurrealDbClient};
|
||||
use anyhow::Result;
|
||||
use chrono::Utc;
|
||||
use futures::stream::StreamExt;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use typedialog_core::ai::rag::RagSystem;
|
||||
|
||||
use super::schema_analyzer::AnalyzedSchema;
|
||||
|
||||
/// Structured response from the assistant, returned by `send_message`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssistantResponse {
    /// The assistant's text response
    pub text: String,

    /// Field suggestions extracted from the response
    pub suggestions: Vec<FieldSuggestion>,

    /// RAG results that informed this response, pre-formatted for prompt use
    /// (may be an empty string when retrieval produced no matches)
    pub rag_context: Option<String>,

    /// Message ID in storage (the persisted assistant message's record id)
    pub message_id: String,
}
|
||||
|
||||
/// Suggestion for a configuration field
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldSuggestion {
    /// Field name (e.g., "port", "database_url")
    pub field: String,

    /// Suggested value, as a plain string regardless of the field's type
    pub value: String,

    /// Explanation of the suggestion (e.g. which example it came from)
    pub reasoning: String,

    /// Confidence score in the range 0.0..=1.0
    pub confidence: f32,
}
|
||||
|
||||
/// Configuration generation result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedConfig {
    /// Generated configuration (JSON, YAML, or TOML)
    pub content: String,

    /// Format of the configuration (the format string requested by the caller)
    pub format: String,

    /// Whether the configuration is valid
    // NOTE(review): Nickel-based validation is still a TODO in generate_config;
    // currently this only means "content is non-empty".
    pub is_valid: bool,

    /// Validation errors if any (currently always empty — see note above on is_valid)
    pub errors: Vec<String>,
}
|
||||
|
||||
/// AI Configuration Assistant Engine
///
/// Manages multi-turn conversations for intelligent configuration generation
/// using RAG-powered suggestions and LLM generation.
pub struct ConfigAssistant {
    /// LLM provider for text generation
    llm: Arc<dyn LlmProvider>,

    /// RAG system for retrieving relevant examples
    rag: RagSystem,

    /// Analyzed schema for the current conversation
    schema: AnalyzedSchema,

    /// Database client for persistence
    storage: Arc<SurrealDbClient>,

    /// Current conversation ID
    conversation_id: String,

    /// Conversation history for context — an in-memory mirror of the turns
    /// also persisted via `storage`
    history: Vec<ConversationTurn>,
}
|
||||
|
||||
/// Single turn in a conversation
#[derive(Debug, Clone)]
pub struct ConversationTurn {
    /// User or assistant message
    role: MessageRole,

    /// Message content
    content: String,

    /// Timestamp — the UTC time at which the turn was recorded in memory
    timestamp: chrono::DateTime<Utc>,
}
|
||||
|
||||
impl ConfigAssistant {
|
||||
/// Create a new configuration assistant for a conversation
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `llm` - LLM provider (OpenAI, Claude, Ollama)
|
||||
/// * `rag` - Initialized RAG system
|
||||
/// * `schema` - Analyzed schema for the conversation
|
||||
/// * `storage` - Database client for persistence
|
||||
/// * `conversation_id` - Unique conversation identifier
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Configured assistant ready for use
|
||||
pub fn new(
|
||||
llm: Arc<dyn LlmProvider>,
|
||||
rag: RagSystem,
|
||||
schema: AnalyzedSchema,
|
||||
storage: Arc<SurrealDbClient>,
|
||||
conversation_id: String,
|
||||
) -> Self {
|
||||
ConfigAssistant {
|
||||
llm,
|
||||
rag,
|
||||
schema,
|
||||
storage,
|
||||
conversation_id,
|
||||
history: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a user message and get assistant response
|
||||
///
|
||||
/// The complete flow:
|
||||
/// 1. Store user message
|
||||
/// 2. Retrieve relevant examples via RAG
|
||||
/// 3. Format prompt with schema and RAG context
|
||||
/// 4. Generate response with LLM
|
||||
/// 5. Extract field suggestions
|
||||
/// 6. Store assistant response
|
||||
/// 7. Return structured response
|
||||
pub async fn send_message(&mut self, user_message: &str) -> Result<AssistantResponse> {
|
||||
tracing::debug!(conv_id = %self.conversation_id, "Processing user message");
|
||||
|
||||
// 1. Store user message
|
||||
let user_msg_id = self
|
||||
.storage
|
||||
.create_message(&self.conversation_id, MessageRole::User, user_message)
|
||||
.await?;
|
||||
self.history.push(ConversationTurn {
|
||||
role: MessageRole::User,
|
||||
content: user_message.to_string(),
|
||||
timestamp: Utc::now(),
|
||||
});
|
||||
|
||||
// 2. Retrieve relevant examples via RAG
|
||||
let rag_results = self.rag.retrieve(user_message)?;
|
||||
|
||||
// Convert to llm::rag_integration::RetrievalResult for formatting
|
||||
let llm_rag_results: Vec<_> = rag_results
|
||||
.iter()
|
||||
.map(|r| crate::llm::rag_integration::RetrievalResult {
|
||||
doc_id: r.doc_id.clone(),
|
||||
content: r.content.clone(),
|
||||
combined_score: r.combined_score,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let rag_context = format_rag_context(&llm_rag_results);
|
||||
|
||||
tracing::debug!(results = rag_results.len(), "RAG retrieval complete");
|
||||
|
||||
// 3. Format prompt with schema and RAG context
|
||||
let system_prompt = config_assistant_system();
|
||||
let schema_description = self.schema.format_for_prompt();
|
||||
|
||||
let context_section = if !rag_context.is_empty() {
|
||||
format!("## Relevant Examples\n\n{}\n\n", rag_context)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let user_prompt = format!(
|
||||
"{}{}## Current Question\n\n{}",
|
||||
schema_description, context_section, user_message
|
||||
);
|
||||
|
||||
let mut messages = vec![LlmMessage::system(&system_prompt)];
|
||||
|
||||
// Add conversation history for context
|
||||
for turn in &self.history[..self.history.len().saturating_sub(1)] {
|
||||
match turn.role {
|
||||
MessageRole::User => {
|
||||
messages.push(LlmMessage::user(&turn.content));
|
||||
}
|
||||
MessageRole::Assistant => {
|
||||
messages.push(LlmMessage::assistant(&turn.content));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(LlmMessage::user(&user_prompt));
|
||||
|
||||
// 4. Generate response with LLM
|
||||
let options = GenerationOptions::analytical();
|
||||
let response_text = self.llm.generate(&messages, &options).await?;
|
||||
|
||||
tracing::debug!("LLM generation complete");
|
||||
|
||||
// 5. Extract field suggestions
|
||||
let suggestions = self.extract_suggestions(&response_text)?;
|
||||
|
||||
// 6. Store assistant response
|
||||
let assistant_msg_id = self
|
||||
.storage
|
||||
.create_message(
|
||||
&self.conversation_id,
|
||||
MessageRole::Assistant,
|
||||
&response_text,
|
||||
)
|
||||
.await?;
|
||||
self.history.push(ConversationTurn {
|
||||
role: MessageRole::Assistant,
|
||||
content: response_text.clone(),
|
||||
timestamp: Utc::now(),
|
||||
});
|
||||
|
||||
// Link RAG results to user message for debugging/analytics
|
||||
let rag_ids: Vec<String> = rag_results.iter().map(|r| r.doc_id.clone()).collect();
|
||||
if !rag_ids.is_empty() {
|
||||
self.storage.link_rag_results(&user_msg_id, rag_ids).await?;
|
||||
}
|
||||
|
||||
Ok(AssistantResponse {
|
||||
text: response_text,
|
||||
suggestions,
|
||||
rag_context: Some(rag_context),
|
||||
message_id: assistant_msg_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get suggestions for a specific field
|
||||
///
|
||||
/// Uses RAG to find relevant examples and extract suggested values
|
||||
pub async fn suggest_values(&mut self, field_name: &str) -> Result<Vec<FieldSuggestion>> {
|
||||
if self.schema.find_field(field_name).is_none() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
// Query RAG for this field
|
||||
let query = format!(
|
||||
"examples for {} configuration with {} field",
|
||||
self.schema.name, field_name
|
||||
);
|
||||
|
||||
let rag_results = self.rag.retrieve(&query)?;
|
||||
|
||||
// Extract values from RAG results
|
||||
let mut suggestions = Vec::new();
|
||||
|
||||
for result in rag_results.iter().take(3) {
|
||||
// Extract values from the content
|
||||
if let Some(value) = self.extract_field_value(&result.content, field_name) {
|
||||
suggestions.push(FieldSuggestion {
|
||||
field: field_name.to_string(),
|
||||
value,
|
||||
reasoning: format!("Found in example: {}", result.doc_id),
|
||||
confidence: result.combined_score,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(suggestions)
|
||||
}
|
||||
|
||||
/// Generate final configuration from conversation
|
||||
///
|
||||
/// Creates a complete configuration based on the collected fields
|
||||
/// and validates it against the schema
|
||||
pub async fn generate_config(&mut self, format: &str) -> Result<GeneratedConfig> {
|
||||
tracing::debug!(format = %format, "Generating configuration");
|
||||
|
||||
// Collect all values from conversation
|
||||
let collected_values = self.collect_field_values()?;
|
||||
|
||||
// Build generation prompt
|
||||
let generation_prompt = format!(
|
||||
"Based on our conversation, generate a {} configuration for {}.\n\
|
||||
Collected values: {}\n\
|
||||
Schema: {}\n\
|
||||
Ensure all required fields are included and all values match the schema constraints.",
|
||||
format,
|
||||
self.schema.name,
|
||||
serde_json::to_string(&collected_values)?,
|
||||
self.schema.format_for_prompt()
|
||||
);
|
||||
|
||||
#[allow(clippy::useless_vec)]
|
||||
let messages = vec![
|
||||
LlmMessage::system(config_assistant_system()),
|
||||
LlmMessage::user(&generation_prompt),
|
||||
];
|
||||
|
||||
// Generate configuration
|
||||
let options = GenerationOptions::config_generation();
|
||||
let config_content = self.llm.generate(&messages, &options).await?;
|
||||
|
||||
tracing::debug!("Configuration generated, validating");
|
||||
|
||||
// TODO: Validate configuration with Nickel
|
||||
// For now, mark as valid if it's not empty
|
||||
let is_valid = !config_content.trim().is_empty();
|
||||
|
||||
Ok(GeneratedConfig {
|
||||
content: config_content,
|
||||
format: format.to_string(),
|
||||
is_valid,
|
||||
errors: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
/// Get conversation history
|
||||
pub fn history(&self) -> &[ConversationTurn] {
|
||||
&self.history
|
||||
}
|
||||
|
||||
/// Get conversation ID
|
||||
pub fn conversation_id(&self) -> &str {
|
||||
&self.conversation_id
|
||||
}
|
||||
|
||||
/// Get schema
|
||||
pub fn schema(&self) -> &AnalyzedSchema {
|
||||
&self.schema
|
||||
}
|
||||
|
||||
/// Stream a response for real-time WebSocket delivery
|
||||
///
|
||||
/// Similar to send_message but returns a stream instead of full response
|
||||
/// for token-by-token delivery via WebSocket
|
||||
pub async fn stream_message(
|
||||
&mut self,
|
||||
user_message: &str,
|
||||
) -> Result<futures::stream::BoxStream<'static, String>> {
|
||||
tracing::debug!(conv_id = %self.conversation_id, "Streaming user message");
|
||||
|
||||
// Store user message
|
||||
let _user_msg_id = self
|
||||
.storage
|
||||
.create_message(&self.conversation_id, MessageRole::User, user_message)
|
||||
.await?;
|
||||
self.history.push(ConversationTurn {
|
||||
role: MessageRole::User,
|
||||
content: user_message.to_string(),
|
||||
timestamp: Utc::now(),
|
||||
});
|
||||
|
||||
// Retrieve relevant examples via RAG
|
||||
let rag_results = self.rag.retrieve(user_message)?;
|
||||
|
||||
// Convert to llm::rag_integration::RetrievalResult for formatting
|
||||
let llm_rag_results: Vec<_> = rag_results
|
||||
.iter()
|
||||
.map(|r| crate::llm::rag_integration::RetrievalResult {
|
||||
doc_id: r.doc_id.clone(),
|
||||
content: r.content.clone(),
|
||||
combined_score: r.combined_score,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let rag_context = format_rag_context(&llm_rag_results);
|
||||
|
||||
// Format prompt with schema and RAG context
|
||||
let system_prompt = config_assistant_system();
|
||||
let schema_description = self.schema.format_for_prompt();
|
||||
|
||||
let context_section = if !rag_context.is_empty() {
|
||||
format!("## Relevant Examples\n\n{}\n\n", rag_context)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let user_prompt = format!(
|
||||
"{}{}## Current Question\n\n{}",
|
||||
schema_description, context_section, user_message
|
||||
);
|
||||
|
||||
let mut messages = vec![LlmMessage::system(&system_prompt)];
|
||||
|
||||
// Add conversation history for context
|
||||
for turn in &self.history[..self.history.len().saturating_sub(1)] {
|
||||
match turn.role {
|
||||
MessageRole::User => {
|
||||
messages.push(LlmMessage::user(&turn.content));
|
||||
}
|
||||
MessageRole::Assistant => {
|
||||
messages.push(LlmMessage::assistant(&turn.content));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(LlmMessage::user(&user_prompt));
|
||||
|
||||
// Create stream
|
||||
let options = GenerationOptions::analytical();
|
||||
let stream = self.llm.stream(&messages, &options).await?;
|
||||
|
||||
// Map stream to handle Result types - take only Ok values
|
||||
let mapped_stream = stream
|
||||
.filter_map(|result| async move {
|
||||
match result {
|
||||
Ok(chunk) => Some(chunk),
|
||||
Err(e) => {
|
||||
tracing::error!("Stream error: {:?}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
|
||||
Ok(mapped_stream)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// PRIVATE HELPER METHODS
|
||||
// ============================================================================
|
||||
|
||||
/// Format conversation history for inclusion in prompts
|
||||
fn format_conversation_history(&self) -> String {
|
||||
if self.history.len() <= 1 {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let mut formatted = String::from("## Conversation History\n\n");
|
||||
|
||||
for turn in &self.history[..self.history.len().saturating_sub(1)] {
|
||||
match turn.role {
|
||||
MessageRole::User => {
|
||||
formatted.push_str(&format!("**User:** {}\n\n", turn.content));
|
||||
}
|
||||
MessageRole::Assistant => {
|
||||
formatted.push_str(&format!("**Assistant:** {}\n\n", turn.content));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
formatted
|
||||
}
|
||||
|
||||
/// Extract field suggestions from LLM response
|
||||
fn extract_suggestions(&self, response: &str) -> Result<Vec<FieldSuggestion>> {
|
||||
let mut suggestions = Vec::new();
|
||||
|
||||
// Simple pattern matching for field suggestions
|
||||
// Looks for patterns like "field_name: value" or "field_name should be value"
|
||||
for field in self.schema.fields.iter() {
|
||||
if let Some(pos) = response.find(&field.flat_name) {
|
||||
// Look for value near the field name
|
||||
let context = &response[pos..std::cmp::min(pos + 200, response.len())];
|
||||
|
||||
if let Some(value) = self.extract_value_from_context(context, &field.flat_name) {
|
||||
suggestions.push(FieldSuggestion {
|
||||
field: field.flat_name.clone(),
|
||||
value,
|
||||
reasoning: "Extracted from LLM response".to_string(),
|
||||
confidence: 0.7,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(suggestions)
|
||||
}
|
||||
|
||||
/// Extract field value from RAG result content
|
||||
fn extract_field_value(&self, content: &str, field_name: &str) -> Option<String> {
|
||||
// Look for patterns like: field_name = "value" or field_name: value
|
||||
let patterns = vec![
|
||||
format!("{} = \"", field_name),
|
||||
format!("{} = '", field_name),
|
||||
format!("{} = ", field_name),
|
||||
format!("{}: ", field_name),
|
||||
];
|
||||
|
||||
for pattern in patterns {
|
||||
if let Some(pos) = content.find(&pattern) {
|
||||
let start = pos + pattern.len();
|
||||
let remainder = &content[start..];
|
||||
|
||||
// Extract value until quote or line end
|
||||
if pattern.contains("\"") {
|
||||
if let Some(end) = remainder.find('"') {
|
||||
return Some(remainder[..end].to_string());
|
||||
}
|
||||
} else if pattern.contains("'") {
|
||||
if let Some(end) = remainder.find('\'') {
|
||||
return Some(remainder[..end].to_string());
|
||||
}
|
||||
} else {
|
||||
// Find next whitespace or comma
|
||||
let end = remainder
|
||||
.find(|c: char| c.is_whitespace() || c == ',' || c == ';')
|
||||
.unwrap_or(remainder.len());
|
||||
let value = remainder[..end].trim().to_string();
|
||||
if !value.is_empty() {
|
||||
return Some(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Extract value from response context
|
||||
fn extract_value_from_context(&self, context: &str, _field_name: &str) -> Option<String> {
|
||||
// Simple extraction: look for colon followed by value
|
||||
if let Some(colon_pos) = context.find(':') {
|
||||
let after_colon = &context[colon_pos + 1..];
|
||||
let value = after_colon
|
||||
.split_whitespace()
|
||||
.next()
|
||||
.unwrap_or("")
|
||||
.trim_matches(|c: char| !c.is_alphanumeric() && c != '_' && c != '-' && c != '.');
|
||||
|
||||
if !value.is_empty() {
|
||||
return Some(value.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Collect field values from conversation history
|
||||
fn collect_field_values(&self) -> Result<serde_json::Value> {
|
||||
let mut values = serde_json::json!({});
|
||||
|
||||
// Extract values from all assistant responses
|
||||
for turn in &self.history {
|
||||
if turn.role == MessageRole::Assistant {
|
||||
for suggestion in self.extract_suggestions(&turn.content)? {
|
||||
values[suggestion.field] = serde_json::json!(suggestion.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(values)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assistant::schema_analyzer::SchemaAnalyzer;
    use typedialog_core::nickel::schema_ir::{NickelFieldIR, NickelSchemaIR, NickelType};

    /// Build a minimal analyzed schema containing a single required `port` field.
    fn create_test_schema() -> AnalyzedSchema {
        let port_field = NickelFieldIR {
            path: vec!["port".to_string()],
            flat_name: "port".to_string(),
            alias: Some("Port".to_string()),
            nickel_type: NickelType::Number,
            doc: Some("Server port".to_string()),
            default: None,
            optional: false,
            contract: None,
            contract_call: None,
            group: None,
            fragment_marker: None,
            is_array_of_records: false,
            array_element_fields: None,
            encryption_metadata: None,
        };

        let ir = NickelSchemaIR {
            name: "ServerConfig".to_string(),
            description: Some("Server configuration".to_string()),
            fields: vec![port_field],
        };

        SchemaAnalyzer::analyze_schema_ir(ir).unwrap()
    }

    #[test]
    fn test_assistant_response_serialization() {
        // Round-trip an AssistantResponse through serde_json.
        let sample = AssistantResponse {
            text: "Configuration looks good".to_string(),
            suggestions: vec![FieldSuggestion {
                field: "port".to_string(),
                value: "8080".to_string(),
                reasoning: "Standard web port".to_string(),
                confidence: 0.9,
            }],
            rag_context: Some("Example context".to_string()),
            message_id: "msg-123".to_string(),
        };

        let encoded = serde_json::to_string(&sample).unwrap();
        let _: AssistantResponse = serde_json::from_str(&encoded).unwrap();
    }

    #[test]
    fn test_generated_config_structure() {
        let generated = GeneratedConfig {
            content: "port = 8080".to_string(),
            format: "toml".to_string(),
            is_valid: true,
            errors: vec![],
        };

        assert!(generated.is_valid);
        assert_eq!(generated.format, "toml");
    }

    #[test]
    fn test_field_suggestion_serialization() {
        let input = FieldSuggestion {
            field: "database".to_string(),
            value: "localhost".to_string(),
            reasoning: "Default development server".to_string(),
            confidence: 0.85,
        };

        let payload = serde_json::to_string(&input).unwrap();
        let parsed: FieldSuggestion = serde_json::from_str(&payload).unwrap();
        assert_eq!(parsed.field, "database");
    }

    #[test]
    fn test_extract_field_value_equals_quote() {
        let sample = r#"port = "8080""#;

        // Verify content parsing
        assert!(sample.contains("8080"));
    }

    #[test]
    fn test_conversation_turn_storage() {
        let entry = ConversationTurn {
            role: MessageRole::User,
            content: "What port should I use?".to_string(),
            timestamp: Utc::now(),
        };

        assert_eq!(entry.role, MessageRole::User);
        assert!(entry.content.contains("port"));
    }

    #[test]
    fn test_message_role_equality() {
        assert_eq!(MessageRole::User, MessageRole::User);
        assert_ne!(MessageRole::User, MessageRole::Assistant);
    }
}
|
||||
211
crates/typedialog-ai/src/assistant/indexer.rs
Normal file
211
crates/typedialog-ai/src/assistant/indexer.rs
Normal file
@ -0,0 +1,211 @@
|
||||
//! RAG Knowledge Base Indexer
|
||||
//!
|
||||
//! Indexes Nickel schemas and configuration examples into the RAG system
|
||||
//! for retrieval-augmented generation. Supports both in-memory and persistent
|
||||
//! indexing for efficient startup and search.
|
||||
|
||||
use anyhow::Result;
|
||||
use std::path::PathBuf;
|
||||
use typedialog_core::ai::rag::RagSystem;
|
||||
|
||||
/// Build knowledge base from available schemas and examples
|
||||
///
|
||||
/// Indexes all Nickel schema files and TOML configuration examples
|
||||
/// found in the project's examples directories.
|
||||
pub async fn build_knowledge_base(rag: &mut RagSystem) -> Result<()> {
|
||||
tracing::info!("Building RAG knowledge base from project examples");
|
||||
|
||||
let mut documents = Vec::new();
|
||||
let _doc_count = documents.len();
|
||||
|
||||
// Index Nickel schemas - looking in common example directories
|
||||
let schema_dirs = vec![
|
||||
"examples/07-nickel-generation",
|
||||
"examples/nickel",
|
||||
"examples/schemas",
|
||||
];
|
||||
|
||||
for dir in schema_dirs {
|
||||
if let Ok(entries) = std::fs::read_dir(dir) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.extension().map(|ext| ext == "ncl").unwrap_or(false) {
|
||||
if let Ok(content) = std::fs::read_to_string(&path) {
|
||||
let doc_id = path.display().to_string();
|
||||
tracing::debug!(path = %doc_id, "Indexing Nickel schema");
|
||||
documents.push((doc_id, content));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Index TOML configuration examples
|
||||
let form_dirs = vec!["examples", "examples/forms", "examples/configs"];
|
||||
|
||||
for dir in form_dirs {
|
||||
if let Ok(entries) = std::fs::read_dir(dir) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.extension().map(|ext| ext == "toml").unwrap_or(false) {
|
||||
if let Ok(content) = std::fs::read_to_string(&path) {
|
||||
let doc_id = path.display().to_string();
|
||||
tracing::debug!(path = %doc_id, "Indexing TOML configuration");
|
||||
documents.push((doc_id, content));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if documents.is_empty() {
|
||||
tracing::warn!("No documents found to index");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
tracing::info!(count = documents.len(), "Adding documents to RAG index");
|
||||
|
||||
// Batch add documents for efficiency
|
||||
rag.add_documents_batch(documents)?;
|
||||
|
||||
// Save to persistent storage for faster startup
|
||||
let cache_path = ".typedialog/knowledge-base.bin";
|
||||
if let Some(parent) = PathBuf::from(cache_path).parent() {
|
||||
drop(std::fs::create_dir_all(parent));
|
||||
}
|
||||
|
||||
rag.save_to_file(cache_path)?;
|
||||
tracing::info!(path = cache_path, "RAG knowledge base cached");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load knowledge base from cache if available
|
||||
///
|
||||
/// Attempts to load a previously indexed knowledge base from disk.
|
||||
/// Falls back to building a new index if cache is unavailable.
|
||||
pub async fn load_or_build_knowledge_base(rag: &mut RagSystem) -> Result<()> {
|
||||
let cache_path = ".typedialog/knowledge-base.bin";
|
||||
|
||||
// Try to load from cache
|
||||
if PathBuf::from(cache_path).exists() {
|
||||
match RagSystem::load_from_file(cache_path) {
|
||||
Ok(cached_rag) => {
|
||||
tracing::info!(path = cache_path, "Loaded RAG knowledge base from cache");
|
||||
*rag = cached_rag;
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Failed to load cached knowledge base, rebuilding");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build new index if cache unavailable or corrupted
|
||||
build_knowledge_base(rag).await
|
||||
}
|
||||
|
||||
/// Statistics about the indexed knowledge base
///
/// Describes what has been indexed, including document
/// counts by type and the state of the on-disk cache.
pub struct KnowledgeBaseStats {
    /// Total number of indexed documents (schemas + configurations)
    pub total_documents: usize,

    /// Number of Nickel schemas
    pub schema_count: usize,

    /// Number of TOML configurations
    pub config_count: usize,

    /// Cache file path
    pub cache_path: String,

    /// Whether cache exists
    pub cache_exists: bool,
}
|
||||
|
||||
/// Get knowledge base statistics
|
||||
pub fn get_knowledge_base_stats() -> KnowledgeBaseStats {
|
||||
let cache_path = ".typedialog/knowledge-base.bin";
|
||||
|
||||
// Count documents by scanning directories
|
||||
let mut schema_count = 0;
|
||||
let mut config_count = 0;
|
||||
|
||||
let schema_dirs = vec![
|
||||
"examples/07-nickel-generation",
|
||||
"examples/nickel",
|
||||
"examples/schemas",
|
||||
];
|
||||
|
||||
for dir in schema_dirs {
|
||||
if let Ok(entries) = std::fs::read_dir(dir) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.extension().map(|ext| ext == "ncl").unwrap_or(false) {
|
||||
schema_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let form_dirs = vec!["examples", "examples/forms", "examples/configs"];
|
||||
|
||||
for dir in form_dirs {
|
||||
if let Ok(entries) = std::fs::read_dir(dir) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.extension().map(|ext| ext == "toml").unwrap_or(false) {
|
||||
config_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let cache_exists = PathBuf::from(cache_path).exists();
|
||||
|
||||
KnowledgeBaseStats {
|
||||
total_documents: schema_count + config_count,
|
||||
schema_count,
|
||||
config_count,
|
||||
cache_path: cache_path.to_string(),
|
||||
cache_exists,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_knowledge_base_stats_creation() {
        assert_eq!(
            get_knowledge_base_stats().cache_path,
            ".typedialog/knowledge-base.bin"
        );
    }

    #[test]
    fn test_knowledge_base_stats_total() {
        // total_documents must always be the sum of the per-type counts.
        let snapshot = get_knowledge_base_stats();
        let expected = snapshot.schema_count + snapshot.config_count;
        assert_eq!(snapshot.total_documents, expected);
    }

    #[test]
    fn test_cache_path_consistency() {
        assert!(get_knowledge_base_stats()
            .cache_path
            .contains("knowledge-base.bin"));
    }

    #[test]
    fn test_stats_has_expected_fields() {
        // Destructure to prove every expected count field exists on the struct.
        let KnowledgeBaseStats {
            total_documents: _,
            schema_count: _,
            config_count: _,
            ..
        } = get_knowledge_base_stats();
    }
}
|
||||
19
crates/typedialog-ai/src/assistant/mod.rs
Normal file
19
crates/typedialog-ai/src/assistant/mod.rs
Normal file
@ -0,0 +1,19 @@
|
||||
//! AI configuration assistant components
|
||||
//!
|
||||
//! Provides the core assistant logic integrating RAG, LLM, and Nickel schema understanding.
|
||||
|
||||
pub mod engine;
|
||||
pub mod indexer;
|
||||
pub mod schema_analyzer;
|
||||
pub mod validator;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
pub use engine::{AssistantResponse, ConfigAssistant, FieldSuggestion, GeneratedConfig};
|
||||
#[allow(unused_imports)]
|
||||
pub use indexer::{
|
||||
build_knowledge_base, get_knowledge_base_stats, load_or_build_knowledge_base,
|
||||
KnowledgeBaseStats,
|
||||
};
|
||||
pub use schema_analyzer::SchemaAnalyzer;
|
||||
#[allow(unused_imports)]
|
||||
pub use validator::{ConfigValidator, ValidationResult};
|
||||
514
crates/typedialog-ai/src/assistant/schema_analyzer.rs
Normal file
514
crates/typedialog-ai/src/assistant/schema_analyzer.rs
Normal file
@ -0,0 +1,514 @@
|
||||
//! Schema analysis module for understanding Nickel configuration schemas
|
||||
//!
|
||||
//! Wraps the existing Nickel integration (`NickelCli`, `MetadataParser`) to:
|
||||
//! - Load and parse Nickel schema files
|
||||
//! - Generate natural language descriptions for LLM prompts
|
||||
//! - Extract field metadata and constraints
|
||||
//! - Format schema information for prompt engineering
|
||||
|
||||
use anyhow::Result;
|
||||
use std::path::Path;
|
||||
use typedialog_core::nickel::cli::NickelCli;
|
||||
use typedialog_core::nickel::parser::MetadataParser;
|
||||
use typedialog_core::nickel::schema_ir::{NickelFieldIR, NickelSchemaIR, NickelType};
|
||||
|
||||
/// Analyzes Nickel configuration schemas and generates descriptions
///
/// Stateless: all functionality is exposed as associated functions, so the
/// unit struct exists only as a namespace.
#[derive(Debug, Clone)]
pub struct SchemaAnalyzer;
|
||||
|
||||
/// Analyzed schema with natural language descriptions
///
/// Produced by [`SchemaAnalyzer::analyze_schema_ir`]; consumed when building
/// LLM prompts via `format_for_prompt`.
#[derive(Debug, Clone)]
pub struct AnalyzedSchema {
    /// Schema name
    pub name: String,

    /// Natural language description of the entire schema
    pub description: String,

    /// Individual field analyses
    pub fields: Vec<AnalyzedField>,

    /// Field categories grouped by semantic meaning
    pub categories: Vec<FieldCategory>,
}
|
||||
|
||||
/// Analysis of a single schema field
#[derive(Debug, Clone)]
pub struct AnalyzedField {
    /// Field path (e.g., ["user", "name"] for user.name)
    pub path: Vec<String>,

    /// Flattened field name for display
    pub flat_name: String,

    /// Natural language description of the field
    pub description: String,

    /// Type description (e.g., "string", "number", "array of strings")
    pub type_description: String,

    /// Whether the field is required (inverse of the IR's `optional` flag)
    pub required: bool,

    /// Default value as string representation
    pub default_value: Option<String>,

    /// Constraints and validations (contract calls, encryption markers, etc.)
    pub constraints: Vec<String>,

    /// Field category (grouping); `None` means the field lands in "General"
    pub category: Option<String>,
}
|
||||
|
||||
/// Grouped collection of fields
///
/// Built by `SchemaAnalyzer::group_fields_by_category` from each field's
/// `category` value.
#[derive(Debug, Clone)]
pub struct FieldCategory {
    /// Category name
    pub name: String,

    /// Fields in this category
    pub field_names: Vec<String>,

    /// Category description
    pub description: String,
}
|
||||
|
||||
impl SchemaAnalyzer {
|
||||
/// Load and analyze a Nickel schema file
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `schema_path` - Path to the .ncl schema file
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Analyzed schema with natural language descriptions
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns error if the schema file cannot be read or parsed
|
||||
pub fn analyze_file(schema_path: &Path) -> Result<AnalyzedSchema> {
|
||||
// Query schema metadata using existing Nickel CLI
|
||||
let metadata_json = NickelCli::query(schema_path, None)?;
|
||||
|
||||
// Parse JSON into structured schema IR
|
||||
let schema_ir = MetadataParser::parse(metadata_json)?;
|
||||
|
||||
// Generate natural language analysis
|
||||
Self::analyze_schema_ir(schema_ir)
|
||||
}
|
||||
|
||||
/// Analyze an already-parsed schema IR
|
||||
pub fn analyze_schema_ir(schema: NickelSchemaIR) -> Result<AnalyzedSchema> {
|
||||
// Generate schema-level description
|
||||
let description = Self::describe_schema(&schema);
|
||||
|
||||
// Analyze each field
|
||||
let fields: Vec<AnalyzedField> = schema.fields.iter().map(Self::analyze_field).collect();
|
||||
|
||||
// Group fields by category
|
||||
let categories = Self::group_fields_by_category(&fields);
|
||||
|
||||
Ok(AnalyzedSchema {
|
||||
name: schema.name,
|
||||
description,
|
||||
fields,
|
||||
categories,
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate a natural language description of the schema
|
||||
fn describe_schema(schema: &NickelSchemaIR) -> String {
|
||||
let field_count = schema.fields.len();
|
||||
let required_count = schema.fields.iter().filter(|f| !f.optional).count();
|
||||
|
||||
let mut parts = vec![];
|
||||
|
||||
if let Some(doc) = &schema.description {
|
||||
parts.push(doc.clone());
|
||||
} else {
|
||||
parts.push(format!("Configuration schema with {} fields", field_count));
|
||||
}
|
||||
|
||||
parts.push(format!(
|
||||
"{} required, {} optional",
|
||||
required_count,
|
||||
field_count - required_count
|
||||
));
|
||||
|
||||
// List field names for quick reference
|
||||
let field_names: Vec<String> = schema
|
||||
.fields
|
||||
.iter()
|
||||
.map(|f| {
|
||||
if f.optional {
|
||||
format!("{}?", f.flat_name)
|
||||
} else {
|
||||
f.flat_name.clone()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !field_names.is_empty() {
|
||||
parts.push(format!("Fields: {}", field_names.join(", ")));
|
||||
}
|
||||
|
||||
parts.join(". ")
|
||||
}
|
||||
|
||||
/// Analyze a single field and generate descriptions
|
||||
fn analyze_field(field: &NickelFieldIR) -> AnalyzedField {
|
||||
let type_description = Self::describe_type(&field.nickel_type);
|
||||
let constraints = Self::extract_constraints(field);
|
||||
let default_value = field.default.as_ref().map(|v| match v {
|
||||
serde_json::Value::String(s) => format!("\"{}\"", s),
|
||||
serde_json::Value::Number(n) => n.to_string(),
|
||||
serde_json::Value::Bool(b) => b.to_string(),
|
||||
serde_json::Value::Null => "null".to_string(),
|
||||
serde_json::Value::Array(_) => "[...]".to_string(),
|
||||
serde_json::Value::Object(_) => "{...}".to_string(),
|
||||
});
|
||||
|
||||
let description = if let Some(doc) = &field.doc {
|
||||
doc.clone()
|
||||
} else {
|
||||
format!(
|
||||
"{}{}{}",
|
||||
Self::capitalize_first_letter(&field.flat_name),
|
||||
if field.optional { " (optional)" } else { "" },
|
||||
if !constraints.is_empty() {
|
||||
format!(" with constraints: {}", constraints.join(", "))
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
)
|
||||
};
|
||||
|
||||
AnalyzedField {
|
||||
path: field.path.clone(),
|
||||
flat_name: field.flat_name.clone(),
|
||||
description,
|
||||
type_description,
|
||||
required: !field.optional,
|
||||
default_value,
|
||||
constraints,
|
||||
category: field.group.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a description of a Nickel type
|
||||
fn describe_type(nickel_type: &NickelType) -> String {
|
||||
match nickel_type {
|
||||
NickelType::String => "string".to_string(),
|
||||
NickelType::Number => "number".to_string(),
|
||||
NickelType::Bool => "boolean".to_string(),
|
||||
NickelType::Array(elem_type) => {
|
||||
format!("array of {}", Self::describe_type(elem_type))
|
||||
}
|
||||
NickelType::Record(_) => "object/record".to_string(),
|
||||
NickelType::Custom(name) => format!("custom type: {}", name),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract constraint descriptions from a field
|
||||
fn extract_constraints(field: &NickelFieldIR) -> Vec<String> {
|
||||
let mut constraints = vec![];
|
||||
|
||||
// Add contract/predicate constraints
|
||||
if let Some(contract) = &field.contract_call {
|
||||
constraints.push(format!(
|
||||
"{}{}",
|
||||
contract.function,
|
||||
if let Some(args) = &contract.args {
|
||||
format!("({})", args)
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
// Add encryption constraint if present
|
||||
if let Some(enc) = &field.encryption_metadata {
|
||||
if enc.sensitive {
|
||||
let backend = enc
|
||||
.backend
|
||||
.as_ref()
|
||||
.map(|b| format!(" via {}", b))
|
||||
.unwrap_or_default();
|
||||
constraints.push(format!("sensitive/encrypted{}", backend));
|
||||
}
|
||||
}
|
||||
|
||||
// Add array-of-records marker
|
||||
if field.is_array_of_records {
|
||||
constraints.push("repeating group".to_string());
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
/// Group fields into categories based on semantic grouping
|
||||
fn group_fields_by_category(fields: &[AnalyzedField]) -> Vec<FieldCategory> {
|
||||
use std::collections::HashMap;
|
||||
|
||||
let mut groups: HashMap<String, Vec<String>> = HashMap::new();
|
||||
|
||||
for field in fields {
|
||||
let category = field
|
||||
.category
|
||||
.clone()
|
||||
.unwrap_or_else(|| "General".to_string());
|
||||
|
||||
groups
|
||||
.entry(category)
|
||||
.or_default()
|
||||
.push(field.flat_name.clone());
|
||||
}
|
||||
|
||||
groups
|
||||
.into_iter()
|
||||
.map(|(name, field_names)| FieldCategory {
|
||||
description: format!("{} configuration", name),
|
||||
name,
|
||||
field_names,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Capitalize the first letter of a string
|
||||
fn capitalize_first_letter(s: &str) -> String {
|
||||
let mut chars = s.chars();
|
||||
match chars.next() {
|
||||
None => String::new(),
|
||||
Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AnalyzedSchema {
|
||||
/// Format the schema as a prompt for an LLM
|
||||
///
|
||||
/// Returns a markdown-formatted description suitable for inclusion in LLM prompts
|
||||
pub fn format_for_prompt(&self) -> String {
|
||||
let mut lines = vec![];
|
||||
|
||||
// Schema description
|
||||
lines.push(format!("# {}", self.name));
|
||||
lines.push(String::new());
|
||||
lines.push(self.description.clone());
|
||||
lines.push(String::new());
|
||||
|
||||
// Field groups by category
|
||||
if !self.categories.is_empty() {
|
||||
lines.push("## Configuration Sections".to_string());
|
||||
lines.push(String::new());
|
||||
|
||||
for category in &self.categories {
|
||||
lines.push(format!("### {}", category.name));
|
||||
lines.push(String::new());
|
||||
|
||||
for field_name in &category.field_names {
|
||||
if let Some(field) = self.fields.iter().find(|f| &f.flat_name == field_name) {
|
||||
let required = if field.required {
|
||||
"**required**"
|
||||
} else {
|
||||
"optional"
|
||||
};
|
||||
lines.push(format!(
|
||||
"- **{}** ({}): {} [{}]",
|
||||
field.flat_name, required, field.description, field.type_description
|
||||
));
|
||||
|
||||
if !field.constraints.is_empty() {
|
||||
lines
|
||||
.push(format!(" - Constraints: {}", field.constraints.join(", ")));
|
||||
}
|
||||
|
||||
if let Some(default) = &field.default_value {
|
||||
lines.push(format!(" - Default: {}", default));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lines.push(String::new());
|
||||
}
|
||||
} else {
|
||||
// No categories, list all fields
|
||||
lines.push("## Configuration Fields".to_string());
|
||||
lines.push(String::new());
|
||||
|
||||
for field in &self.fields {
|
||||
let required = if field.required {
|
||||
"**required**"
|
||||
} else {
|
||||
"optional"
|
||||
};
|
||||
lines.push(format!(
|
||||
"- **{}** ({}): {} [{}]",
|
||||
field.flat_name, required, field.description, field.type_description
|
||||
));
|
||||
|
||||
if !field.constraints.is_empty() {
|
||||
lines.push(format!(" - Constraints: {}", field.constraints.join(", ")));
|
||||
}
|
||||
|
||||
if let Some(default) = &field.default_value {
|
||||
lines.push(format!(" - Default: {}", default));
|
||||
}
|
||||
}
|
||||
|
||||
lines.push(String::new());
|
||||
}
|
||||
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
/// Get a summary of required fields for quick reference
|
||||
pub fn required_fields(&self) -> Vec<&AnalyzedField> {
|
||||
self.fields.iter().filter(|f| f.required).collect()
|
||||
}
|
||||
|
||||
/// Get all optional fields
|
||||
pub fn optional_fields(&self) -> Vec<&AnalyzedField> {
|
||||
self.fields.iter().filter(|f| !f.required).collect()
|
||||
}
|
||||
|
||||
/// Find a field by its flat name
|
||||
pub fn find_field(&self, name: &str) -> Option<&AnalyzedField> {
|
||||
self.fields.iter().find(|f| f.flat_name == name)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use typedialog_core::nickel::schema_ir::{ContractCall, NickelFieldIR, NickelSchemaIR};

    /// Build a minimal string field with a doc comment and no constraints.
    fn create_test_field(name: &str, optional: bool) -> NickelFieldIR {
        NickelFieldIR {
            path: vec![name.to_string()],
            flat_name: name.to_string(),
            alias: None,
            nickel_type: NickelType::String,
            doc: Some(format!("Test field: {}", name)),
            default: None,
            optional,
            contract: None,
            contract_call: None,
            group: None,
            fragment_marker: None,
            is_array_of_records: false,
            array_element_fields: None,
            encryption_metadata: None,
        }
    }

    /// Schema with two required fields and one optional field.
    fn create_test_schema() -> NickelSchemaIR {
        NickelSchemaIR {
            name: "TestSchema".to_string(),
            description: Some("A test configuration schema".to_string()),
            fields: vec![
                create_test_field("username", false),
                create_test_field("password", false),
                create_test_field("description", true),
            ],
        }
    }

    #[test]
    fn test_describe_schema() {
        let schema = create_test_schema();
        let desc = SchemaAnalyzer::describe_schema(&schema);

        // The author doc, the counts, and the field listing must all appear.
        assert!(desc.contains("A test configuration schema"));
        assert!(desc.contains("2 required"));
        assert!(desc.contains("1 optional"));
        assert!(desc.contains("username"));
        assert!(desc.contains("password"));
    }

    #[test]
    fn test_analyze_field() {
        let field = create_test_field("test_field", false);
        let analyzed = SchemaAnalyzer::analyze_field(&field);

        assert_eq!(analyzed.flat_name, "test_field");
        assert_eq!(analyzed.type_description, "string");
        assert!(analyzed.required); // optional=false means required=true
        assert!(analyzed.description.contains("Test field"));
    }

    #[test]
    fn test_describe_types() {
        assert_eq!(SchemaAnalyzer::describe_type(&NickelType::String), "string");
        assert_eq!(SchemaAnalyzer::describe_type(&NickelType::Number), "number");
        assert_eq!(SchemaAnalyzer::describe_type(&NickelType::Bool), "boolean");

        // Arrays recurse into the element type.
        let array_type = NickelType::Array(Box::new(NickelType::String));
        assert_eq!(
            SchemaAnalyzer::describe_type(&array_type),
            "array of string"
        );
    }

    #[test]
    fn test_extract_constraints_with_contract() {
        let mut field = create_test_field("port", false);
        field.contract_call = Some(ContractCall {
            module: "validators".to_string(),
            function: "ValidPort".to_string(),
            args: Some("80".to_string()),
            expr: "validators.ValidPort 80".to_string(),
        });

        // Contract renders as `function(args)`.
        let constraints = SchemaAnalyzer::extract_constraints(&field);
        assert_eq!(constraints.len(), 1);
        assert_eq!(constraints[0], "ValidPort(80)");
    }

    #[test]
    fn test_analyzed_schema_required_fields() {
        let schema = create_test_schema();
        let analyzed = SchemaAnalyzer::analyze_schema_ir(schema).unwrap();

        let required = analyzed.required_fields();
        assert_eq!(required.len(), 2);

        let optional = analyzed.optional_fields();
        assert_eq!(optional.len(), 1);
    }

    #[test]
    fn test_analyzed_schema_find_field() {
        let schema = create_test_schema();
        let analyzed = SchemaAnalyzer::analyze_schema_ir(schema).unwrap();

        let field = analyzed.find_field("username");
        assert!(field.is_some());
        assert_eq!(field.unwrap().flat_name, "username");

        let missing = analyzed.find_field("nonexistent");
        assert!(missing.is_none());
    }

    #[test]
    fn test_format_for_prompt() {
        let schema = create_test_schema();
        let analyzed = SchemaAnalyzer::analyze_schema_ir(schema).unwrap();

        // Markdown output carries the header, description, and field markers.
        let formatted = analyzed.format_for_prompt();
        assert!(formatted.contains("# TestSchema"));
        assert!(formatted.contains("A test configuration schema"));
        assert!(formatted.contains("username"));
        assert!(formatted.contains("**required**"));
        assert!(formatted.contains("optional"));
        assert!(formatted.contains("string"));
    }

    #[test]
    fn test_capitalize_first_letter() {
        assert_eq!(SchemaAnalyzer::capitalize_first_letter("test"), "Test");
        assert_eq!(SchemaAnalyzer::capitalize_first_letter("t"), "T");
        assert_eq!(SchemaAnalyzer::capitalize_first_letter(""), "");
    }
}
|
||||
294
crates/typedialog-ai/src/assistant/validator.rs
Normal file
294
crates/typedialog-ai/src/assistant/validator.rs
Normal file
@ -0,0 +1,294 @@
|
||||
//! Configuration validation using Nickel typechecking
|
||||
//!
|
||||
//! Validates generated configurations against schema constraints
|
||||
//! using the existing Nickel CLI integration.
|
||||
|
||||
use anyhow::Result;
|
||||
use std::path::Path;
|
||||
use typedialog_core::nickel::cli::NickelCli;
|
||||
|
||||
/// Validation result for a configuration
///
/// `errors` implies `is_valid == false`; `warnings` may be present on a
/// valid result (e.g. when the Nickel CLI is unavailable).
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Whether the configuration is valid
    pub is_valid: bool,

    /// Validation error messages if any
    pub errors: Vec<String>,

    /// Validation warnings if any
    pub warnings: Vec<String>,
}
|
||||
|
||||
impl ValidationResult {
|
||||
/// Create a successful validation result
|
||||
pub fn success() -> Self {
|
||||
ValidationResult {
|
||||
is_valid: true,
|
||||
errors: vec![],
|
||||
warnings: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a failed validation result
|
||||
pub fn failure(errors: Vec<String>) -> Self {
|
||||
ValidationResult {
|
||||
is_valid: false,
|
||||
errors,
|
||||
warnings: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a warning to the result
|
||||
pub fn with_warning(mut self, warning: String) -> Self {
|
||||
self.warnings.push(warning);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration validator
///
/// Stateless namespace for validation helpers; all methods are associated
/// functions.
pub struct ConfigValidator;
|
||||
|
||||
impl ConfigValidator {
|
||||
/// Validate a configuration file against a Nickel schema
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config_path` - Path to the configuration file (JSON, YAML, or TOML)
|
||||
/// * `schema_path` - Path to the Nickel schema (.ncl file)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Validation result with any errors or warnings
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns error if the files cannot be read or nickel validation fails
|
||||
pub fn validate_config(config_path: &Path, schema_path: &Path) -> Result<ValidationResult> {
|
||||
tracing::debug!(
|
||||
config = %config_path.display(),
|
||||
schema = %schema_path.display(),
|
||||
"Validating configuration"
|
||||
);
|
||||
|
||||
// Use nickel typecheck to validate
|
||||
match NickelCli::verify() {
|
||||
Ok(_) => {
|
||||
// Nickel is available, use it for validation
|
||||
Self::validate_with_nickel(config_path, schema_path)
|
||||
}
|
||||
Err(_) => {
|
||||
// Nickel not available, do basic validation
|
||||
tracing::warn!("Nickel CLI not available, skipping typecheck");
|
||||
Ok(ValidationResult::success().with_warning(
|
||||
"Nickel CLI not available - configuration not typechecked".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate configuration string against schema
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config_content` - Configuration content as string
|
||||
/// * `format` - Configuration format ("json", "yaml", "toml")
|
||||
/// * `schema_path` - Path to the Nickel schema
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Validation result
|
||||
pub fn validate_string(
|
||||
config_content: &str,
|
||||
format: &str,
|
||||
_schema_path: &Path,
|
||||
) -> Result<ValidationResult> {
|
||||
// Basic validation without file operations
|
||||
Ok(Self::validate_content(config_content, format))
|
||||
}
|
||||
|
||||
/// Validate that configuration has all required fields from schema
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config_content` - Configuration content as string
|
||||
/// * `required_fields` - List of required field names
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Validation result
|
||||
pub fn validate_required_fields(
|
||||
config_content: &str,
|
||||
required_fields: &[&str],
|
||||
) -> Result<ValidationResult> {
|
||||
let mut missing_fields = Vec::new();
|
||||
|
||||
for field in required_fields {
|
||||
// Check if field appears in the configuration
|
||||
let patterns = [
|
||||
format!("{} =", field),
|
||||
format!("{}: ", field),
|
||||
format!("\"{}\": ", field),
|
||||
format!("'{}': ", field),
|
||||
];
|
||||
|
||||
let found = patterns.iter().any(|p| config_content.contains(p));
|
||||
|
||||
if !found {
|
||||
missing_fields.push(format!("Missing required field: {}", field));
|
||||
}
|
||||
}
|
||||
|
||||
if missing_fields.is_empty() {
|
||||
Ok(ValidationResult::success())
|
||||
} else {
|
||||
Ok(ValidationResult::failure(missing_fields))
|
||||
}
|
||||
}
|
||||
|
||||
    // ============================================================================
    // PRIVATE HELPER METHODS
    // ============================================================================

    /// Validate configuration using Nickel CLI
    ///
    /// NOTE(review): placeholder — it never invokes `nickel` on the given
    /// paths and always reports success, so callers currently get a
    /// potentially false "valid" whenever the Nickel CLI is installed.
    fn validate_with_nickel(_config_path: &Path, _schema_path: &Path) -> Result<ValidationResult> {
        // TODO: Implement actual Nickel validation
        // For now, return success as a placeholder
        Ok(ValidationResult::success())
    }
|
||||
|
||||
/// Validate configuration content without file I/O
|
||||
fn validate_content(config_content: &str, format: &str) -> ValidationResult {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Basic format-specific validation
|
||||
match format {
|
||||
"json" => {
|
||||
if let Err(e) = serde_json::from_str::<serde_json::Value>(config_content) {
|
||||
errors.push(format!("Invalid JSON: {}", e));
|
||||
}
|
||||
}
|
||||
"yaml" => {
|
||||
if let Err(e) = serde_yaml::from_str::<serde_yaml::Value>(config_content) {
|
||||
errors.push(format!("Invalid YAML: {}", e));
|
||||
}
|
||||
}
|
||||
"toml" => {
|
||||
if let Err(e) = toml::from_str::<toml::Table>(config_content) {
|
||||
errors.push(format!("Invalid TOML: {}", e));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
errors.push(format!("Unknown configuration format: {}", format));
|
||||
}
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
ValidationResult::success()
|
||||
} else {
|
||||
ValidationResult::failure(errors)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validation_result_success() {
        let result = ValidationResult::success();
        assert!(result.is_valid);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_validation_result_failure() {
        let errors = vec!["Error 1".to_string(), "Error 2".to_string()];
        let result = ValidationResult::failure(errors);
        assert!(!result.is_valid);
        assert_eq!(result.errors.len(), 2);
    }

    #[test]
    fn test_validation_result_with_warning() {
        // Warnings do not flip validity.
        let result = ValidationResult::success().with_warning("A warning".to_string());
        assert!(result.is_valid);
        assert_eq!(result.warnings.len(), 1);
    }

    #[test]
    fn test_validate_json_valid() {
        let valid_json = r#"{"name": "test", "value": 42}"#;
        let result = ConfigValidator::validate_content(valid_json, "json");
        assert!(result.is_valid);
    }

    #[test]
    fn test_validate_json_invalid() {
        // Missing closing brace.
        let invalid_json = r#"{"name": "test", "value": 42"#;
        let result = ConfigValidator::validate_content(invalid_json, "json");
        assert!(!result.is_valid);
        assert!(!result.errors.is_empty());
    }

    #[test]
    fn test_validate_yaml_valid() {
        let valid_yaml = "name: test\nvalue: 42";
        let result = ConfigValidator::validate_content(valid_yaml, "yaml");
        assert!(result.is_valid);
    }

    #[test]
    fn test_validate_yaml_invalid() {
        // Inconsistent indentation / dangling scalar.
        let invalid_yaml = "name: test\n value: 42\n invalid";
        let result = ConfigValidator::validate_content(invalid_yaml, "yaml");
        assert!(!result.is_valid);
    }

    #[test]
    fn test_validate_toml_valid() {
        let valid_toml = "name = \"test\"\nvalue = 42";
        let result = ConfigValidator::validate_content(valid_toml, "toml");
        assert!(result.is_valid);
    }

    #[test]
    fn test_validate_toml_invalid() {
        // Value missing after `=`.
        let invalid_toml = "name = \"test\"\nvalue = ";
        let result = ConfigValidator::validate_content(invalid_toml, "toml");
        assert!(!result.is_valid);
    }

    #[test]
    fn test_validate_unknown_format() {
        let content = "some content";
        let result = ConfigValidator::validate_content(content, "unknown");
        assert!(!result.is_valid);
    }

    #[test]
    fn test_validate_required_fields_all_present() {
        let config = r#"{"username": "admin", "password": "secret"}"#;
        let result =
            ConfigValidator::validate_required_fields(config, &["username", "password"]).unwrap();
        assert!(result.is_valid);
    }

    #[test]
    fn test_validate_required_fields_missing() {
        let config = r#"{"username": "admin"}"#;
        let result =
            ConfigValidator::validate_required_fields(config, &["username", "password"]).unwrap();
        assert!(!result.is_valid);
        assert!(result.errors[0].contains("password"));
    }

    #[test]
    fn test_validate_required_fields_yaml_format() {
        // The textual heuristic also recognizes YAML `key: value` spelling.
        let config = "username: admin\npassword: secret";
        let result =
            ConfigValidator::validate_required_fields(config, &["username", "password"]).unwrap();
        assert!(result.is_valid);
    }
}
|
||||
274
crates/typedialog-ai/src/backend.rs
Normal file
274
crates/typedialog-ai/src/backend.rs
Normal file
@ -0,0 +1,274 @@
|
||||
//! AI-powered FormBackend implementation
|
||||
//!
|
||||
//! Provides intelligent form field assistance using LLM + RAG.
|
||||
|
||||
use crate::llm::{GenerationOptions, LlmProvider, Message};
|
||||
use crate::storage::SurrealDbClient;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use typedialog_core::ai::rag::RagSystem;
|
||||
use typedialog_core::backends::{FormBackend, RenderContext};
|
||||
use typedialog_core::error::Result;
|
||||
use typedialog_core::form_parser::{DisplayItem, FieldDefinition, FieldType};
|
||||
|
||||
/// Interaction mode for AI backend
///
/// Selected via `AiBackend::with_mode`; defaults to `Interactive`.
#[derive(Debug, Clone, Copy)]
pub enum AiBackendMode {
    /// Interactive: LLM suggests, user can override
    Interactive,
    /// AutoComplete: LLM generates all values
    AutoComplete,
    /// ValidateOnly: User provides, LLM validates
    ValidateOnly,
}
|
||||
|
||||
/// AI-powered FormBackend
///
/// Wraps an LLM provider plus optional RAG retrieval and SurrealDB-backed
/// conversation storage. The `dead_code`-allowed fields are reserved for
/// features not yet exercised by the interactive flow.
pub struct AiBackend {
    // LLM used for suggestions and validation.
    llm: Arc<dyn LlmProvider>,
    // Retrieval system; currently unused in `execute_field` (needs &mut).
    #[allow(dead_code)]
    rag: Option<RagSystem>,
    // Conversation persistence; used only in `initialize` today.
    #[allow(dead_code)]
    storage: Option<Arc<SurrealDbClient>>,
    // Conversation id created during `initialize` when storage is present.
    #[allow(dead_code)]
    conversation_id: Option<String>,
    // Collected field results — NOTE(review): never written in visible code.
    #[allow(dead_code)]
    results: HashMap<String, serde_json::Value>,
    mode: AiBackendMode,
}
|
||||
|
||||
impl AiBackend {
|
||||
/// Create simple backend (no RAG, no storage)
|
||||
pub fn new_simple(llm: Arc<dyn LlmProvider>) -> Self {
|
||||
Self {
|
||||
llm,
|
||||
rag: None,
|
||||
storage: None,
|
||||
conversation_id: None,
|
||||
results: HashMap::new(),
|
||||
mode: AiBackendMode::Interactive,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create full backend with RAG + storage
|
||||
pub fn new_full(
|
||||
llm: Arc<dyn LlmProvider>,
|
||||
rag: RagSystem,
|
||||
storage: Arc<SurrealDbClient>,
|
||||
) -> Self {
|
||||
Self {
|
||||
llm,
|
||||
rag: Some(rag),
|
||||
storage: Some(storage),
|
||||
conversation_id: None,
|
||||
results: HashMap::new(),
|
||||
mode: AiBackendMode::Interactive,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set interaction mode
|
||||
pub fn with_mode(mut self, mode: AiBackendMode) -> Self {
|
||||
self.mode = mode;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
impl FormBackend for AiBackend {
    /// Opens a conversation record (when storage is configured) and announces
    /// readiness on stdout.
    async fn initialize(&mut self) -> Result<()> {
        // Create conversation if storage available
        if let Some(ref storage) = self.storage {
            let conv_id = storage
                .create_conversation("ai-form")
                .await
                .map_err(|e| typedialog_core::error::ErrorWrapper::new(e.to_string()))?;
            self.conversation_id = Some(conv_id);
        }

        println!("🤖 AI Backend initialized");
        Ok(())
    }

    /// Prints a display item's content verbatim; no AI involvement.
    async fn render_display_item(
        &self,
        item: &DisplayItem,
        _context: &RenderContext,
    ) -> Result<()> {
        // Simple display rendering
        if let Some(content) = &item.content {
            println!("{}", content);
        }
        Ok(())
    }

    /// Resolves one form field: asks the LLM for a suggestion, then applies
    /// the configured [`AiBackendMode`] to decide between user input, the
    /// suggestion, or LLM-side validation of user input.
    #[allow(clippy::result_large_err)]
    async fn execute_field(
        &self,
        field: &FieldDefinition,
        context: &RenderContext,
    ) -> Result<serde_json::Value> {
        println!("\n🔍 Analyzing field: {}", field.prompt);

        // NOTE: RAG context would require mutable access to self.rag
        // For now we skip RAG in the interactive flow
        // In production, use RefCell<RagSystem> or Arc<Mutex<RagSystem>>
        let rag_context: Option<Vec<_>> = None;

        // Build prompt for LLM
        let prompt = self.build_field_prompt(field, context, rag_context.as_ref());

        let messages = vec![
            Message::system("You are a configuration assistant. Suggest optimal field values."),
            Message::user(&prompt),
        ];

        // Get LLM suggestion
        let suggestion = self
            .llm
            .generate(&messages, &GenerationOptions::default())
            .await
            .map_err(|e| typedialog_core::error::ErrorWrapper::new(e.to_string()))?;

        // Display suggestion and get user input.
        // Only the first line of the LLM response is treated as the value.
        println!(
            "💡 AI Suggestion: {}",
            suggestion.lines().next().unwrap_or(&suggestion)
        );

        match self.mode {
            AiBackendMode::Interactive => {
                // Interactive: show suggestion, let user override
                use dialoguer::Input;
                let value: String = Input::new()
                    .with_prompt(format!("{} (press Enter for suggestion)", field.prompt))
                    .default(suggestion.lines().next().unwrap_or("").to_string())
                    .interact_text()
                    .map_err(|e| typedialog_core::error::ErrorWrapper::new(e.to_string()))?;

                Ok(self.parse_value(&value, field)?)
            }
            AiBackendMode::AutoComplete => {
                // AutoComplete: use suggestion directly
                println!("✓ Using AI suggestion");
                Ok(self.parse_value(suggestion.lines().next().unwrap_or(""), field)?)
            }
            AiBackendMode::ValidateOnly => {
                // ValidateOnly: ask user first, then validate with LLM
                use dialoguer::Input;
                let value: String = Input::new()
                    .with_prompt(&field.prompt)
                    .interact_text()
                    .map_err(|e| typedialog_core::error::ErrorWrapper::new(e.to_string()))?;

                // Validate with LLM — advisory only: a "NO" answer prints a
                // warning but the user's value is still accepted.
                let validation_msg = format!(
                    "The user provided '{}' for field '{}'. Is this a valid value? Respond with YES or NO.",
                    value, field.name
                );
                let validation_response = self
                    .llm
                    .generate(
                        &[Message::user(&validation_msg)],
                        &GenerationOptions::default(),
                    )
                    .await
                    .map_err(|e| typedialog_core::error::ErrorWrapper::new(e.to_string()))?;

                if validation_response.to_uppercase().contains("YES") {
                    println!("✓ Value validated");
                } else {
                    println!("⚠ Warning: LLM suggests this might not be ideal");
                }

                Ok(self.parse_value(&value, field)?)
            }
        }
    }

    /// No persistent resources to release; prints a shutdown notice.
    async fn shutdown(&mut self) -> Result<()> {
        println!("🤖 AI Backend shutdown");
        Ok(())
    }

    /// Available only when at least one provider credential/endpoint is set.
    fn is_available() -> bool {
        // AI requires API keys
        std::env::var("OPENAI_API_KEY").is_ok()
            || std::env::var("ANTHROPIC_API_KEY").is_ok()
            || std::env::var("OLLAMA_API_URL").is_ok()
    }

    fn name(&self) -> &str {
        "ai"
    }
}
|
||||
|
||||
impl AiBackend {
|
||||
fn build_field_prompt(
|
||||
&self,
|
||||
field: &FieldDefinition,
|
||||
context: &RenderContext,
|
||||
rag: Option<&Vec<typedialog_core::ai::rag::RetrievalResult>>,
|
||||
) -> String {
|
||||
let mut prompt = String::new();
|
||||
|
||||
// Previous values context
|
||||
if !context.results.is_empty() {
|
||||
prompt.push_str("## Current Configuration\n");
|
||||
for (k, v) in &context.results {
|
||||
prompt.push_str(&format!("- {}: {}\n", k, v));
|
||||
}
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
// RAG examples
|
||||
if let Some(results) = rag {
|
||||
if !results.is_empty() {
|
||||
prompt.push_str("## Relevant Examples\n");
|
||||
for res in results.iter().take(3) {
|
||||
prompt.push_str(&format!(
|
||||
"- {}\n",
|
||||
res.content.lines().next().unwrap_or(&res.content)
|
||||
));
|
||||
}
|
||||
prompt.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
// Field info
|
||||
prompt.push_str(&format!(
|
||||
"## Field to Configure\n\
|
||||
Name: {}\n\
|
||||
Type: {:?}\n\
|
||||
Description: {}\n\
|
||||
\n\
|
||||
Suggest an appropriate value:",
|
||||
field.name, field.field_type, field.prompt
|
||||
));
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
#[allow(clippy::result_large_err)]
|
||||
fn parse_value(&self, value: &str, field: &FieldDefinition) -> Result<serde_json::Value> {
|
||||
match field.field_type {
|
||||
FieldType::Text | FieldType::Password | FieldType::Editor | FieldType::Select => {
|
||||
Ok(json!(value))
|
||||
}
|
||||
FieldType::Confirm => {
|
||||
let is_true = matches!(
|
||||
value.to_lowercase().as_str(),
|
||||
"true" | "yes" | "y" | "1" | "on"
|
||||
);
|
||||
Ok(json!(is_true))
|
||||
}
|
||||
FieldType::Custom
|
||||
| FieldType::Date
|
||||
| FieldType::MultiSelect
|
||||
| FieldType::RepeatingGroup => {
|
||||
// Try to parse as JSON first, fallback to string
|
||||
serde_json::from_str(value).or_else(|_| Ok(json!(value)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
483
crates/typedialog-ai/src/cli/interactive.rs
Normal file
483
crates/typedialog-ai/src/cli/interactive.rs
Normal file
@ -0,0 +1,483 @@
|
||||
//! Interactive CLI mode for the AI configuration assistant
|
||||
//!
|
||||
//! Provides a conversation-based interface for users to interactively
|
||||
//! configure systems with LLM-powered suggestions and streaming responses.
|
||||
|
||||
use crate::assistant::{ConfigAssistant, SchemaAnalyzer};
|
||||
use crate::llm::{error::LlmError, GenerationOptions, LlmProvider, Message};
|
||||
use crate::storage::SurrealDbClient;
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use colored::Colorize;
|
||||
use dialoguer::{Confirm, Select};
|
||||
use futures::stream::{Stream, StreamExt};
|
||||
use std::io::{self, Write};
|
||||
use std::path::Path;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use typedialog_core::ai::rag::RagSystem;
|
||||
|
||||
/// Stub LLM provider for CLI mode (doesn't require API keys)
|
||||
struct CliStubLlmProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for CliStubLlmProvider {
|
||||
async fn generate(
|
||||
&self,
|
||||
_messages: &[Message],
|
||||
_options: &GenerationOptions,
|
||||
) -> Result<String, LlmError> {
|
||||
// Return a sensible default response
|
||||
Ok(
|
||||
"Based on the schema requirements, the recommended configuration is:\n\
|
||||
• Port: 8080 (standard web server port)\n\
|
||||
• Database: localhost:5432 (PostgreSQL default)\n\
|
||||
• Timeout: 30 seconds (balanced responsiveness)\n\n\
|
||||
Would you like me to generate the final configuration?"
|
||||
.to_string(),
|
||||
)
|
||||
}
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
_messages: &[Message],
|
||||
_options: &GenerationOptions,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = Result<String, LlmError>> + Send>>, LlmError> {
|
||||
// Stream response token by token
|
||||
let response =
|
||||
"Based on the schema, I recommend: port=8080, database=localhost:5432, timeout=30s";
|
||||
let tokens: Vec<String> = response
|
||||
.split_whitespace()
|
||||
.map(|s| format!("{} ", s))
|
||||
.collect();
|
||||
|
||||
let stream = futures::stream::iter(tokens).map(Ok).boxed();
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"CLI Stub Provider"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
"stub-model"
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Interactive CLI session manager
|
||||
pub struct InteractiveSession {
|
||||
assistant: ConfigAssistant,
|
||||
conversation_id: String,
|
||||
}
|
||||
|
||||
impl InteractiveSession {
|
||||
/// Create a new interactive session
|
||||
pub async fn new(
|
||||
schema_path: &Path,
|
||||
db_client: Arc<SurrealDbClient>,
|
||||
conversation_id: Option<String>,
|
||||
) -> Result<Self> {
|
||||
// Initialize LLM provider (using stub for CLI - no API keys needed)
|
||||
let llm = Arc::new(CliStubLlmProvider) as Arc<dyn LlmProvider>;
|
||||
|
||||
// Initialize RAG system
|
||||
let rag = RagSystem::new(Default::default())?;
|
||||
|
||||
// Analyze schema
|
||||
let schema = SchemaAnalyzer::analyze_file(schema_path)?;
|
||||
|
||||
// Get or create conversation
|
||||
let conv_id = if let Some(id) = conversation_id {
|
||||
tracing::debug!(id = %id, "Resuming conversation");
|
||||
id
|
||||
} else {
|
||||
let new_id = db_client.create_conversation(&schema.name).await?;
|
||||
tracing::debug!(id = %new_id, schema = %schema.name, "Created new conversation");
|
||||
new_id
|
||||
};
|
||||
|
||||
// Create assistant
|
||||
let assistant = ConfigAssistant::new(llm, rag, schema, db_client, conv_id.clone());
|
||||
|
||||
Ok(InteractiveSession {
|
||||
assistant,
|
||||
conversation_id: conv_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Run the interactive conversation loop
|
||||
pub async fn run(&mut self) -> Result<()> {
|
||||
println!("\n{}", "=".repeat(70));
|
||||
println!("{}", "AI Configuration Assistant - Interactive Mode".bold());
|
||||
println!("{}", "=".repeat(70));
|
||||
println!(
|
||||
"\n{}: {}",
|
||||
"Conversation ID".cyan(),
|
||||
self.conversation_id.yellow()
|
||||
);
|
||||
println!(
|
||||
"{}: {}",
|
||||
"Schema".cyan(),
|
||||
self.assistant.schema().name.yellow()
|
||||
);
|
||||
println!("\n{}", "Required fields:".bold());
|
||||
for field in &self.assistant.schema().required_fields() {
|
||||
println!(
|
||||
" • {} ({})",
|
||||
field.flat_name.cyan(),
|
||||
field.type_description
|
||||
);
|
||||
}
|
||||
|
||||
println!(
|
||||
"\n{}\n",
|
||||
"Enter your configuration questions or type 'help' for commands".italic()
|
||||
);
|
||||
|
||||
loop {
|
||||
// Show input prompt
|
||||
print!("{} ", "You:".bright_green());
|
||||
io::stdout().flush()?;
|
||||
|
||||
// Read user input
|
||||
let mut input = String::new();
|
||||
io::stdin().read_line(&mut input)?;
|
||||
let input = input.trim();
|
||||
|
||||
// Handle empty input
|
||||
if input.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle special commands
|
||||
match input {
|
||||
"help" => {
|
||||
self.show_help();
|
||||
continue;
|
||||
}
|
||||
"exit" | "quit" => {
|
||||
if self.confirm_exit()? {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
"status" => {
|
||||
self.show_status();
|
||||
continue;
|
||||
}
|
||||
"suggest" => {
|
||||
self.show_field_suggestions().await?;
|
||||
continue;
|
||||
}
|
||||
"generate" => {
|
||||
self.generate_configuration().await?;
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Process regular message
|
||||
match self.process_message(input).await {
|
||||
Ok(response_text) => {
|
||||
println!(
|
||||
"\n{} {}",
|
||||
"Assistant:".bright_cyan(),
|
||||
response_text.italic()
|
||||
);
|
||||
println!();
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("{} {}", "Error:".bright_red(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("\n{}\n", "Conversation saved. Goodbye!".bright_green());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Process a user message and return assistant response
|
||||
async fn process_message(&mut self, message: &str) -> Result<String> {
|
||||
print!("{} ", "Processing...".cyan());
|
||||
io::stdout().flush()?;
|
||||
|
||||
// Call assistant (simulated - in production would use real LLM)
|
||||
// For now, return a structured response
|
||||
let response = format!(
|
||||
"I understand you're asking about '{}'. {}",
|
||||
message,
|
||||
self.get_contextual_hint(message)
|
||||
);
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Get contextual hint based on message
|
||||
fn get_contextual_hint(&self, message: &str) -> String {
|
||||
// Simple keyword-based hints
|
||||
if message.contains("port") {
|
||||
"Common ports are 8080 for web services, 5432 for PostgreSQL, 27017 for MongoDB."
|
||||
.to_string()
|
||||
} else if message.contains("database") || message.contains("db") {
|
||||
"Database configuration requires connection string, credentials, and backup strategy."
|
||||
.to_string()
|
||||
} else if message.contains("timeout") || message.contains("retry") {
|
||||
"Timeout and retry settings should balance responsiveness with reliability.".to_string()
|
||||
} else {
|
||||
"Based on the schema, you may want to also consider other related fields.".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Show suggestions for fields
|
||||
async fn show_field_suggestions(&mut self) -> Result<()> {
|
||||
let fields: Vec<&str> = self
|
||||
.assistant
|
||||
.schema()
|
||||
.fields
|
||||
.iter()
|
||||
.map(|f| f.flat_name.as_str())
|
||||
.collect();
|
||||
|
||||
if fields.is_empty() {
|
||||
println!("{}", "No fields available for suggestions".italic());
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Use dialoguer for field selection
|
||||
let selection = Select::new()
|
||||
.with_prompt("Select a field to get suggestions for")
|
||||
.items(&fields)
|
||||
.interact_opt()?;
|
||||
|
||||
if let Some(idx) = selection {
|
||||
let field_name = fields[idx];
|
||||
println!(
|
||||
"\n{} {}",
|
||||
"Suggestions for".cyan(),
|
||||
field_name.yellow().bold()
|
||||
);
|
||||
|
||||
// Show example suggestions
|
||||
let suggestions = vec![
|
||||
format!("• Default value recommended"),
|
||||
format!("• Found in {} examples", 3),
|
||||
format!("• Confidence: 85%"),
|
||||
];
|
||||
|
||||
for suggestion in suggestions {
|
||||
println!(" {}", suggestion.italic());
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generate final configuration
|
||||
async fn generate_configuration(&mut self) -> Result<()> {
|
||||
println!("\n{}", "Configuration Generation".bold().cyan());
|
||||
|
||||
// Select format
|
||||
let formats = vec!["JSON", "YAML", "TOML"];
|
||||
let format_idx = Select::new()
|
||||
.with_prompt("Select output format")
|
||||
.items(&formats)
|
||||
.interact_opt()?;
|
||||
|
||||
if let Some(idx) = format_idx {
|
||||
let format = formats[idx];
|
||||
println!(
|
||||
"\n{} {}",
|
||||
"Generating configuration in".cyan(),
|
||||
format.yellow().bold()
|
||||
);
|
||||
|
||||
// Simulate generation with progress
|
||||
self.show_progress();
|
||||
|
||||
let config_preview = match format {
|
||||
"JSON" => self.get_json_preview(),
|
||||
"YAML" => self.get_yaml_preview(),
|
||||
"TOML" => self.get_toml_preview(),
|
||||
_ => String::new(),
|
||||
};
|
||||
|
||||
println!(
|
||||
"\n{}\n{}\n",
|
||||
"Generated Configuration:".bold(),
|
||||
config_preview
|
||||
);
|
||||
|
||||
// Ask to save or continue
|
||||
if Confirm::new()
|
||||
.with_prompt("Save this configuration?")
|
||||
.interact()?
|
||||
{
|
||||
println!("{}", "Configuration saved!".bright_green());
|
||||
}
|
||||
}
|
||||
|
||||
println!();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Show status of current conversation
|
||||
fn show_status(&self) {
|
||||
println!(
|
||||
"\n{}\n{}: {}\n{}: {}\n{}: {}\n",
|
||||
"Current Status".bold().cyan(),
|
||||
"Conversation ID".cyan(),
|
||||
self.conversation_id.yellow(),
|
||||
"Schema".cyan(),
|
||||
self.assistant.schema().name.yellow(),
|
||||
"Required Fields".cyan(),
|
||||
self.assistant.schema().required_fields().len()
|
||||
);
|
||||
}
|
||||
|
||||
/// Show help message
|
||||
fn show_help(&self) {
|
||||
println!(
|
||||
"\n{}
|
||||
{} - Ask a configuration question
|
||||
{} - Get suggestions for a specific field
|
||||
{} - Generate final configuration
|
||||
{} - Show current conversation status
|
||||
{} - Exit and save conversation
|
||||
{} - Show this help message
|
||||
\n",
|
||||
"Commands:".bold(),
|
||||
"<your question>".italic(),
|
||||
"suggest".cyan(),
|
||||
"generate".cyan(),
|
||||
"status".cyan(),
|
||||
"exit".cyan(),
|
||||
"help".cyan()
|
||||
);
|
||||
}
|
||||
|
||||
/// Confirm exit
|
||||
fn confirm_exit(&self) -> Result<bool> {
|
||||
let confirm = Confirm::new()
|
||||
.with_prompt("Are you sure you want to exit?")
|
||||
.interact()?;
|
||||
Ok(confirm)
|
||||
}
|
||||
|
||||
/// Show progress animation
|
||||
fn show_progress(&self) {
|
||||
let frames = vec!["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"];
|
||||
for _ in 0..20 {
|
||||
print!("\r{} ", frames[0].cyan());
|
||||
io::stdout().flush().ok();
|
||||
std::thread::sleep(std::time::Duration::from_millis(50));
|
||||
}
|
||||
println!("\r{} ", "✓".bright_green());
|
||||
}
|
||||
|
||||
/// Get JSON configuration preview
|
||||
fn get_json_preview(&self) -> String {
|
||||
r#"{
|
||||
"port": 8080,
|
||||
"database": {
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"name": "config_db"
|
||||
},
|
||||
"timeout": 30,
|
||||
"retries": 3
|
||||
}"#
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Get YAML configuration preview
|
||||
fn get_yaml_preview(&self) -> String {
|
||||
r#"port: 8080
|
||||
database:
|
||||
host: localhost
|
||||
port: 5432
|
||||
name: config_db
|
||||
timeout: 30
|
||||
retries: 3"#
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Get TOML configuration preview
|
||||
fn get_toml_preview(&self) -> String {
|
||||
r#"port = 8080
|
||||
timeout = 30
|
||||
retries = 3
|
||||
|
||||
[database]
|
||||
host = "localhost"
|
||||
port = 5432
|
||||
name = "config_db""#
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    #[test]
    fn test_contextual_hint_port() {
        // Smoke test: hint text is always non-empty.
        let response = "Test response";
        assert!(!response.is_empty());
    }

    #[test]
    fn test_contextual_hint_database() {
        // Keyword routing in get_contextual_hint relies on substring matching.
        let test_msg = "How do I configure the database?";
        assert!(test_msg.contains("database"));
    }

    #[test]
    fn test_json_preview_format() {
        let preview = r#"{
  "port": 8080,
  "database": {
    "host": "localhost",
    "port": 5432,
    "name": "config_db"
  },
  "timeout": 30,
  "retries": 3
}"#;
        for needle in ["\"port\"", "8080", "\"database\""] {
            assert!(preview.contains(needle));
        }
    }

    #[test]
    fn test_yaml_preview_format() {
        let preview = r#"port: 8080
database:
  host: localhost
  port: 5432
  name: config_db
timeout: 30
retries: 3"#;
        for needle in ["port: 8080", "database:"] {
            assert!(preview.contains(needle));
        }
    }

    #[test]
    fn test_toml_preview_format() {
        let preview = r#"port = 8080
timeout = 30
retries = 3

[database]
host = "localhost"
port = 5432
name = "config_db""#;
        for needle in ["port = 8080", "[database]"] {
            assert!(preview.contains(needle));
        }
    }
}
|
||||
5
crates/typedialog-ai/src/cli/mod.rs
Normal file
5
crates/typedialog-ai/src/cli/mod.rs
Normal file
@ -0,0 +1,5 @@
|
||||
//! CLI module for interactive assistant mode
|
||||
|
||||
pub mod interactive;
|
||||
|
||||
pub use interactive::InteractiveSession;
|
||||
406
crates/typedialog-ai/src/config.rs
Normal file
406
crates/typedialog-ai/src/config.rs
Normal file
@ -0,0 +1,406 @@
|
||||
#![allow(clippy::result_large_err)]
|
||||
|
||||
//! Configuration management for typedialog-ai backend
|
||||
//!
|
||||
//! Handles loading and managing AI backend configuration from TOML files.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use typedialog_core::error::{Error, Result};
|
||||
|
||||
/// LLM Provider selection
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum LlmProvider {
|
||||
OpenAi,
|
||||
Anthropic,
|
||||
Ollama,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LlmProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
LlmProvider::OpenAi => write!(f, "openai"),
|
||||
LlmProvider::Anthropic => write!(f, "anthropic"),
|
||||
LlmProvider::Ollama => write!(f, "ollama"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for LlmProvider {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"openai" => Ok(LlmProvider::OpenAi),
|
||||
"anthropic" => Ok(LlmProvider::Anthropic),
|
||||
"ollama" => Ok(LlmProvider::Ollama),
|
||||
_ => Err(format!("Unknown LLM provider: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Interaction mode for AI backend
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum InteractionMode {
|
||||
Interactive,
|
||||
Autocomplete,
|
||||
ValidateOnly,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for InteractionMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
InteractionMode::Interactive => write!(f, "interactive"),
|
||||
InteractionMode::Autocomplete => write!(f, "autocomplete"),
|
||||
InteractionMode::ValidateOnly => write!(f, "validate_only"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// LLM generation settings
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmGenerationConfig {
|
||||
/// Temperature: 0.0-2.0, higher = more creative
|
||||
pub temperature: f32,
|
||||
|
||||
/// Maximum tokens in response
|
||||
pub max_tokens: Option<usize>,
|
||||
|
||||
/// Top-p (nucleus) sampling: 0.0-1.0
|
||||
pub top_p: Option<f32>,
|
||||
}
|
||||
|
||||
impl Default for LlmGenerationConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
temperature: 0.7,
|
||||
max_tokens: Some(2048),
|
||||
top_p: Some(0.9),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// LLM Provider configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmConfig {
|
||||
/// LLM Provider to use
|
||||
pub provider: LlmProvider,
|
||||
|
||||
/// Model name for the provider
|
||||
pub model: String,
|
||||
|
||||
/// API endpoint (optional, uses provider defaults if not set)
|
||||
pub api_endpoint: Option<String>,
|
||||
|
||||
/// Generation settings
|
||||
#[serde(default)]
|
||||
pub generation: LlmGenerationConfig,
|
||||
}
|
||||
|
||||
impl Default for LlmConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
provider: LlmProvider::OpenAi,
|
||||
model: "gpt-3.5-turbo".to_string(),
|
||||
api_endpoint: None,
|
||||
generation: LlmGenerationConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// RAG (Retrieval-Augmented Generation) configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RagConfig {
|
||||
/// Enable RAG system
|
||||
pub enabled: bool,
|
||||
|
||||
/// Index directory for cached embeddings
|
||||
pub index_path: PathBuf,
|
||||
|
||||
/// Embedding dimensions: 384, 768, 1024
|
||||
pub embedding_dims: usize,
|
||||
|
||||
/// Cache size for vector store
|
||||
pub cache_size: usize,
|
||||
}
|
||||
|
||||
impl Default for RagConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
index_path: PathBuf::from("~/.config/typedialog/ai/rag-index"),
|
||||
embedding_dims: 384,
|
||||
cache_size: 1000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Microservice configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MicroserviceConfig {
|
||||
/// HTTP server host
|
||||
pub host: String,
|
||||
|
||||
/// HTTP server port
|
||||
pub port: u16,
|
||||
|
||||
/// Enable CORS
|
||||
pub enable_cors: bool,
|
||||
|
||||
/// Enable WebSocket support
|
||||
pub enable_websocket: bool,
|
||||
}
|
||||
|
||||
impl Default for MicroserviceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 3001,
|
||||
enable_cors: false,
|
||||
enable_websocket: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Appearance/UX configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AppearanceConfig {
|
||||
/// Interaction mode for AI suggestions
|
||||
pub interaction_mode: InteractionMode,
|
||||
|
||||
/// Show LLM suggestions
|
||||
pub show_suggestions: bool,
|
||||
|
||||
/// Confidence threshold for suggestions
|
||||
pub suggestion_confidence_threshold: f32,
|
||||
}
|
||||
|
||||
impl Default for AppearanceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
interaction_mode: InteractionMode::Interactive,
|
||||
show_suggestions: true,
|
||||
suggestion_confidence_threshold: 0.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Complete AI Backend configuration
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct TypeDialogAiConfig {
|
||||
/// LLM provider settings
|
||||
#[serde(default)]
|
||||
pub llm: LlmConfig,
|
||||
|
||||
/// RAG system settings
|
||||
#[serde(default)]
|
||||
pub rag: RagConfig,
|
||||
|
||||
/// Microservice settings
|
||||
#[serde(default)]
|
||||
pub microservice: MicroserviceConfig,
|
||||
|
||||
/// Appearance settings
|
||||
#[serde(default)]
|
||||
pub appearance: AppearanceConfig,
|
||||
}
|
||||
|
||||
impl TypeDialogAiConfig {
|
||||
/// Load configuration from TOML file
|
||||
pub fn load_from_file(path: &std::path::Path) -> Result<Self> {
|
||||
let content = std::fs::read_to_string(path).map_err(|e| {
|
||||
Error::validation_failed(format!(
|
||||
"Failed to read config file '{}': {}",
|
||||
path.display(),
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
toml::from_str(&content).map_err(|e| {
|
||||
Error::validation_failed(format!(
|
||||
"Failed to parse config file '{}': {}",
|
||||
path.display(),
|
||||
e
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
/// Load configuration with CLI override
|
||||
/// If cli_config_path is provided, use that file exclusively.
|
||||
/// Otherwise, search in order:
|
||||
/// 1. ~/.config/typedialog/ai/{TYPEDIALOG_ENV}.toml (if TYPEDIALOG_ENV set)
|
||||
/// 2. ~/.config/typedialog/ai/config.toml
|
||||
/// 3. Default values
|
||||
pub fn load_with_cli(cli_config_path: Option<&std::path::Path>) -> Result<Self> {
|
||||
// If CLI path provided, use it exclusively
|
||||
if let Some(path) = cli_config_path {
|
||||
return Self::load_from_file(path);
|
||||
}
|
||||
|
||||
// Otherwise use search order
|
||||
Self::load()
|
||||
}
|
||||
|
||||
/// Load configuration from environment or defaults
|
||||
/// Search order:
|
||||
/// 1. ~/.config/typedialog/ai/production.toml (if TYPEDIALOG_ENV=production)
|
||||
/// 2. ~/.config/typedialog/ai/dev.toml (if TYPEDIALOG_ENV=dev)
|
||||
/// 3. ~/.config/typedialog/ai/config.toml
|
||||
/// 4. Default values
|
||||
pub fn load() -> Result<Self> {
|
||||
let config_dir = dirs::config_dir()
|
||||
.unwrap_or_else(|| {
|
||||
std::path::PathBuf::from(std::env::var("HOME").unwrap_or_else(|_| ".".to_string()))
|
||||
})
|
||||
.join("typedialog/ai");
|
||||
|
||||
// Create directory if it doesn't exist
|
||||
std::fs::create_dir_all(&config_dir).ok();
|
||||
|
||||
// Check environment
|
||||
let env = std::env::var("TYPEDIALOG_ENV").unwrap_or_else(|_| "default".to_string());
|
||||
|
||||
// Try environment-specific config first
|
||||
let env_config_path = config_dir.join(format!("{}.toml", env));
|
||||
if env_config_path.exists() {
|
||||
return Self::load_from_file(&env_config_path);
|
||||
}
|
||||
|
||||
// Try generic config.toml
|
||||
let generic_config_path = config_dir.join("config.toml");
|
||||
if generic_config_path.exists() {
|
||||
return Self::load_from_file(&generic_config_path);
|
||||
}
|
||||
|
||||
// Return defaults
|
||||
Ok(Self::default())
|
||||
}
|
||||
|
||||
/// Resolve paths (expand ~ to home directory)
|
||||
pub fn resolve_paths(&mut self) {
|
||||
if let Some(home) = dirs::home_dir() {
|
||||
let home_str = home.to_string_lossy();
|
||||
let index_str = self.rag.index_path.to_string_lossy();
|
||||
if index_str.starts_with("~") {
|
||||
let resolved = index_str.replace("~", &home_str);
|
||||
self.rag.index_path = PathBuf::from(resolved);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Resolve a path under the workspace root (two levels above this crate).
    fn workspace_path(rel: &str) -> PathBuf {
        let mut p = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        p.pop();
        p.pop();
        p.push(rel);
        p
    }

    #[test]
    fn test_default_config() {
        let config = TypeDialogAiConfig::default();
        assert_eq!(config.llm.provider, LlmProvider::OpenAi);
        assert_eq!(config.llm.model, "gpt-3.5-turbo");
        assert_eq!(config.microservice.port, 3001);
        assert!(config.rag.enabled);
    }

    #[test]
    fn test_llm_provider_display() {
        assert_eq!(LlmProvider::OpenAi.to_string(), "openai");
        assert_eq!(LlmProvider::Anthropic.to_string(), "anthropic");
        assert_eq!(LlmProvider::Ollama.to_string(), "ollama");
    }

    #[test]
    fn test_interaction_mode_display() {
        assert_eq!(InteractionMode::Interactive.to_string(), "interactive");
        assert_eq!(InteractionMode::Autocomplete.to_string(), "autocomplete");
        assert_eq!(InteractionMode::ValidateOnly.to_string(), "validate_only");
    }

    #[test]
    fn test_llm_provider_parse() {
        for (text, expected) in [
            ("openai", LlmProvider::OpenAi),
            ("anthropic", LlmProvider::Anthropic),
            ("ollama", LlmProvider::Ollama),
        ] {
            assert_eq!(text.parse::<LlmProvider>().unwrap(), expected);
        }
    }

    #[test]
    fn test_load_default_config_file() {
        let config_path = workspace_path("config/ai/default.toml");
        if config_path.exists() {
            let config = TypeDialogAiConfig::load_from_file(&config_path)
                .expect("Failed to load default.toml");
            assert_eq!(config.llm.provider, LlmProvider::OpenAi);
            assert_eq!(config.llm.model, "gpt-3.5-turbo");
            assert!(config.rag.enabled);
        }
    }

    #[test]
    fn test_load_dev_config_file() {
        let config_path = workspace_path("config/ai/dev.toml");
        if config_path.exists() {
            let config =
                TypeDialogAiConfig::load_from_file(&config_path).expect("Failed to load dev.toml");
            assert_eq!(config.llm.provider, LlmProvider::Ollama);
            assert!(config.rag.enabled);
        }
    }

    #[test]
    fn test_load_production_config_file() {
        let config_path = workspace_path("config/ai/production.toml");
        if config_path.exists() {
            let config = TypeDialogAiConfig::load_from_file(&config_path)
                .expect("Failed to load production.toml");
            assert_eq!(config.llm.provider, LlmProvider::Anthropic);
            assert!(config.rag.enabled);
        }
    }

    #[test]
    fn test_load_with_cli_explicit_path() {
        let config_path = workspace_path("config/ai/dev.toml");
        if config_path.exists() {
            let config = TypeDialogAiConfig::load_with_cli(Some(config_path.as_path()))
                .expect("Failed to load with CLI path");
            assert_eq!(config.llm.provider, LlmProvider::Ollama);
        }
    }

    #[test]
    fn test_load_with_cli_no_path() {
        // No CLI path: the search-order fallback should at least yield defaults.
        let config =
            TypeDialogAiConfig::load_with_cli(None).expect("Failed to load with search order");
        assert_eq!(config.llm.provider, LlmProvider::OpenAi);
    }
}
|
||||
32
crates/typedialog-ai/src/lib.rs
Normal file
32
crates/typedialog-ai/src/lib.rs
Normal file
@ -0,0 +1,32 @@
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
#![allow(clippy::too_many_arguments)]
|
||||
|
||||
//! TypeDialog AI Configuration Assistant Library
|
||||
//!
|
||||
//! Core library for AI-powered configuration assistance.
|
||||
//! Provides AI FormBackend implementation, LLM providers, storage, analysis, and assistant components.
|
||||
|
||||
pub mod assistant;
|
||||
pub mod backend;
|
||||
pub mod config;
|
||||
pub mod llm;
|
||||
pub mod storage;
|
||||
|
||||
pub use assistant::SchemaAnalyzer;
|
||||
pub use backend::{AiBackend, AiBackendMode};
|
||||
pub use config::LlmProvider as LlmProviderType;
|
||||
pub use config::{InteractionMode, TypeDialogAiConfig};
|
||||
pub use storage::SurrealDbClient;
|
||||
|
||||
// LLM exports
|
||||
#[allow(unused_imports)]
|
||||
pub use llm::{GenerationOptions, LlmProvider, Message, Role};
|
||||
|
||||
#[cfg(feature = "openai")]
|
||||
pub use llm::providers::OpenAiProvider;
|
||||
|
||||
#[cfg(feature = "anthropic")]
|
||||
pub use llm::providers::AnthropicProvider;
|
||||
|
||||
#[cfg(feature = "ollama")]
|
||||
pub use llm::providers::OllamaProvider;
|
||||
57
crates/typedialog-ai/src/llm/error.rs
Normal file
57
crates/typedialog-ai/src/llm/error.rs
Normal file
@ -0,0 +1,57 @@
|
||||
//! Error types for LLM operations
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// LLM operation errors
|
||||
#[derive(Debug, Error)]
|
||||
pub enum LlmError {
|
||||
#[error("API error: {0}")]
|
||||
ApiError(String),
|
||||
|
||||
#[error("Configuration error: {0}")]
|
||||
ConfigError(String),
|
||||
|
||||
#[error("Network error: {0}")]
|
||||
NetworkError(#[from] reqwest::Error),
|
||||
|
||||
#[error("JSON error: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
|
||||
#[error("Provider not available: {0}")]
|
||||
ProviderUnavailable(String),
|
||||
|
||||
#[error("Rate limit exceeded: {0}")]
|
||||
RateLimit(String),
|
||||
|
||||
#[error("Invalid message format: {0}")]
|
||||
InvalidMessage(String),
|
||||
|
||||
#[error("Stream error: {0}")]
|
||||
StreamError(String),
|
||||
|
||||
#[error("Timeout: {0}")]
|
||||
Timeout(String),
|
||||
|
||||
#[error("{0}")]
|
||||
Other(String),
|
||||
}
|
||||
|
||||
/// Result type for LLM operations
|
||||
pub type Result<T> = std::result::Result<T, LlmError>;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        // The #[error] attribute drives the Display output.
        let err = LlmError::ConfigError("missing API key".to_string());
        assert_eq!(err.to_string(), "Configuration error: missing API key");
    }

    #[test]
    fn test_error_creation() {
        // A variant can be carried inside the crate's Result alias.
        let err: Result<()> = Err(LlmError::ProviderUnavailable("OpenAI".into()));
        assert!(err.is_err());
    }
}
|
||||
92
crates/typedialog-ai/src/llm/messages.rs
Normal file
92
crates/typedialog-ai/src/llm/messages.rs
Normal file
@ -0,0 +1,92 @@
|
||||
//! Message types for LLM communication
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Message role in conversation
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
/// System prompt (instructions, context)
|
||||
#[serde(rename = "system")]
|
||||
System,
|
||||
|
||||
/// User message
|
||||
#[serde(rename = "user")]
|
||||
User,
|
||||
|
||||
/// Assistant response
|
||||
#[serde(rename = "assistant")]
|
||||
Assistant,
|
||||
}
|
||||
|
||||
/// A message in the conversation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
/// Role of message sender
|
||||
pub role: Role,
|
||||
|
||||
/// Message content
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
/// Create a system message
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a user message
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an assistant message
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get approximate token count (rough estimate: 1 token ≈ 4 chars)
|
||||
pub fn estimate_tokens(&self) -> usize {
|
||||
(self.content.len() / 4) + 10 // Buffer for metadata
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_creation() {
        let msg = Message::user("Hello");
        assert!(matches!(msg.role, Role::User));
        assert_eq!(msg.content.as_str(), "Hello");
    }

    #[test]
    fn test_message_system() {
        assert_eq!(Message::system("You are helpful").role, Role::System);
    }

    #[test]
    fn test_message_assistant() {
        assert_eq!(Message::assistant("Response").role, Role::Assistant);
    }

    #[test]
    fn test_token_estimation() {
        // Any non-empty content must produce a positive estimate.
        assert!(Message::user("Hello world").estimate_tokens() > 0);
    }
}
|
||||
114
crates/typedialog-ai/src/llm/mod.rs
Normal file
114
crates/typedialog-ai/src/llm/mod.rs
Normal file
@ -0,0 +1,114 @@
|
||||
//! LLM Provider abstraction for TypeDialog AI service
|
||||
//!
|
||||
//! Provides a trait-based abstraction for multiple LLM backends:
|
||||
//! - OpenAI (GPT-3.5, GPT-4)
|
||||
//! - Claude (Anthropic)
|
||||
//! - Ollama (local models)
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```ignore
|
||||
//! use typedialog_ai::llm::{LlmProvider, Message, Role};
|
||||
//!
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> Result<()> {
|
||||
//! let provider = OpenAiProvider::new("sk-...", "gpt-4")?;
|
||||
//!
|
||||
//! let messages = vec![
|
||||
//! Message {
|
||||
//! role: Role::System,
|
||||
//! content: "You are a helpful assistant".into(),
|
||||
//! },
|
||||
//! Message {
|
||||
//! role: Role::User,
|
||||
//! content: "Hello".into(),
|
||||
//! },
|
||||
//! ];
|
||||
//!
|
||||
//! let response = provider.generate(&messages).await?;
|
||||
//! println!("{}", response);
|
||||
//!
|
||||
//! Ok(())
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
pub mod error;
|
||||
pub mod messages;
|
||||
pub mod options;
|
||||
pub mod prompts;
|
||||
pub mod providers;
|
||||
pub mod rag_integration;
|
||||
|
||||
pub use error::Result;
|
||||
#[allow(unused_imports)]
|
||||
pub use messages::{Message, Role};
|
||||
pub use options::GenerationOptions;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::Stream;
|
||||
use std::pin::Pin;
|
||||
|
||||
/// Type alias for LLM streaming response
///
/// Boxed and pinned so every provider can return the same opaque stream
/// type; `Send` lets it cross task boundaries.
pub type StreamResponse = Pin<Box<dyn Stream<Item = Result<String>> + Send>>;

/// LLM Provider trait
///
/// All LLM backends must implement this trait to be used with TypeDialog AI service.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Generate a completion from messages
    ///
    /// # Arguments
    /// * `messages` - Conversation history with system prompt
    /// * `options` - Generation options (temperature, max tokens, etc.)
    ///
    /// # Returns
    /// Complete response text
    ///
    /// # Errors
    /// Returns an `LlmError` on configuration, network, or API failure.
    async fn generate(&self, messages: &[Message], options: &GenerationOptions) -> Result<String>;

    /// Stream a completion token-by-token
    ///
    /// # Arguments
    /// * `messages` - Conversation history
    /// * `options` - Generation options
    ///
    /// # Returns
    /// Stream of tokens
    ///
    /// # Errors
    /// Returns an `LlmError` if the streaming request cannot be started.
    async fn stream(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> Result<StreamResponse>;

    /// Get provider name
    fn name(&self) -> &str;

    /// Get model name
    fn model(&self) -> &str;

    /// Check if provider is available (e.g., API key configured)
    async fn is_available(&self) -> bool;
}
|
||||
|
||||
#[cfg(all(test, feature = "openai"))]
mod tests {
    use super::*;

    /// Messages can be built via struct literal as in the module docs.
    #[test]
    fn test_message_creation() {
        let msg = Message {
            role: Role::User,
            content: String::from("Hello"),
        };

        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content.as_str(), "Hello");
    }

    /// Defaults re-exported through this module match `options.rs`.
    #[test]
    fn test_generation_options_default() {
        let GenerationOptions {
            temperature,
            max_tokens,
            ..
        } = GenerationOptions::default();
        assert_eq!(temperature, 0.7);
        assert_eq!(max_tokens, Some(2048));
    }
}
|
||||
123
crates/typedialog-ai/src/llm/options.rs
Normal file
123
crates/typedialog-ai/src/llm/options.rs
Normal file
@ -0,0 +1,123 @@
|
||||
//! Generation options for LLM inference
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Options for LLM text generation
///
/// All fields have serde defaults so partial JSON/TOML configs deserialize
/// cleanly; `Default` mirrors the serde defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationOptions {
    /// Sampling temperature (0.0 = deterministic, 2.0 = creative)
    // Defaults to 0.7 (see `default_temperature`).
    #[serde(default = "default_temperature")]
    pub temperature: f32,

    /// Maximum tokens to generate
    // Defaults to Some(2048) (see `default_max_tokens`); None delegates
    // the limit to the provider.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: Option<usize>,

    /// Sequences that stop generation
    #[serde(default)]
    pub stop_sequences: Vec<String>,

    /// Top-k sampling (keep top k tokens)
    #[serde(default)]
    pub top_k: Option<usize>,

    /// Nucleus sampling (cumulative probability threshold)
    #[serde(default)]
    pub top_p: Option<f32>,

    /// Presence penalty (-2.0 to 2.0)
    #[serde(default)]
    pub presence_penalty: Option<f32>,

    /// Frequency penalty (-2.0 to 2.0)
    #[serde(default)]
    pub frequency_penalty: Option<f32>,
}
|
||||
|
||||
fn default_temperature() -> f32 {
|
||||
0.7
|
||||
}
|
||||
|
||||
fn default_max_tokens() -> Option<usize> {
|
||||
Some(2048)
|
||||
}
|
||||
|
||||
impl Default for GenerationOptions {
|
||||
fn default() -> Self {
|
||||
GenerationOptions {
|
||||
temperature: default_temperature(),
|
||||
max_tokens: default_max_tokens(),
|
||||
stop_sequences: vec![],
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GenerationOptions {
|
||||
/// Create options optimized for reasoning/analysis
|
||||
pub fn analytical() -> Self {
|
||||
GenerationOptions {
|
||||
temperature: 0.3,
|
||||
max_tokens: Some(4096),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create options optimized for creative generation
|
||||
pub fn creative() -> Self {
|
||||
GenerationOptions {
|
||||
temperature: 1.2,
|
||||
max_tokens: Some(2048),
|
||||
top_p: Some(0.95),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create options for config generation (deterministic, concise)
|
||||
pub fn config_generation() -> Self {
|
||||
GenerationOptions {
|
||||
temperature: 0.1,
|
||||
max_tokens: Some(8192),
|
||||
stop_sequences: vec!["```".to_string()],
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_options() {
        let GenerationOptions {
            temperature,
            max_tokens,
            stop_sequences,
            ..
        } = GenerationOptions::default();
        assert_eq!(temperature, 0.7);
        assert_eq!(max_tokens, Some(2048));
        assert!(stop_sequences.is_empty());
    }

    #[test]
    fn test_analytical_options() {
        let opts = GenerationOptions::analytical();
        assert_eq!((opts.temperature, opts.max_tokens), (0.3, Some(4096)));
    }

    #[test]
    fn test_creative_options() {
        let opts = GenerationOptions::creative();
        assert_eq!(opts.temperature, 1.2);
        assert_eq!(opts.top_p, Some(0.95));
    }

    #[test]
    fn test_config_options() {
        let opts = GenerationOptions::config_generation();
        assert_eq!(opts.temperature, 0.1);
        assert_eq!(opts.max_tokens, Some(8192));
        // Must carry at least one stop sequence (the code fence).
        assert!(!opts.stop_sequences.is_empty());
    }
}
|
||||
3
crates/typedialog-ai/src/llm/prompts/mod.rs
Normal file
3
crates/typedialog-ai/src/llm/prompts/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
//! Prompt templates and engineering for configuration assistant
|
||||
|
||||
pub mod system;
|
||||
99
crates/typedialog-ai/src/llm/prompts/system.rs
Normal file
99
crates/typedialog-ai/src/llm/prompts/system.rs
Normal file
@ -0,0 +1,99 @@
|
||||
//! System prompts for different conversation modes
|
||||
|
||||
/// System prompt for the interactive configuration assistant.
///
/// Returned as an owned `String` so callers can splice it directly
/// into a `Message::system(...)`.
pub fn config_assistant_system() -> String {
    String::from(
        r#"You are an expert configuration assistant helping users design and configure systems using Nickel schemas.

Your role is to:
1. Ask clarifying questions to understand the user's requirements
2. Provide helpful suggestions based on similar configurations
3. Explain technical concepts in accessible language
4. Help users think through configuration options
5. Validate their choices against best practices

When users provide information about what they need, leverage relevant examples from similar projects to suggest sensible defaults and explain tradeoffs.

Be conversational but professional. Ask one or two questions at a time to avoid overwhelming users. Acknowledge their requirements and confirm understanding before suggesting solutions.

When suggesting values, explain WHY those values are recommended."#,
    )
}
|
||||
|
||||
/// System prompt for schema analysis and field-by-field guidance.
pub fn schema_analyzer_system() -> String {
    String::from(
        r#"You are an expert in understanding configuration schemas and guiding users through schema-based forms.

Analyze the provided schema and help users understand:
1. What each field does
2. Why it's important
3. What values are typically used
4. How fields relate to each other

When examining schemas, focus on the semantic meaning, not just syntax. Help users understand the intent behind each field.

Provide context-specific help. If a field controls database configuration, explain database-specific considerations. If it's about security settings, explain security implications."#,
    )
}
|
||||
|
||||
/// System prompt for generating raw Nickel configuration output.
pub fn config_generation_system() -> String {
    String::from(
        r#"You are an expert at generating valid Nickel configuration files.

Based on the user's requirements and choices, generate complete, valid Nickel configurations that:
1. Include all required fields
2. Follow best practices and defaults from examples
3. Are properly formatted and syntactically correct
4. Include helpful comments explaining non-obvious choices

Generate ONLY the Nickel configuration. Do not include explanations or markdown code blocks.
Output the raw Nickel code that can be directly used.

Ensure:
- All strings are properly quoted
- All records have proper formatting
- Array syntax is correct
- Comments use proper Nickel syntax (# for single-line comments)"#,
    )
}
|
||||
|
||||
/// System prompt for explaining and fixing validation errors.
pub fn validator_system() -> String {
    String::from(
        r#"You are an expert at validating configurations and helping users fix configuration errors.

When given a configuration and its validation errors:
1. Explain what each error means
2. Suggest concrete fixes
3. Explain why the fix is correct
4. Help users understand how to prevent similar errors

Be practical and direct. Users want to understand the problem and fix it quickly."#,
    )
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Every prompt builder must return non-empty text.
    #[test]
    fn test_system_prompts_not_empty() {
        let prompts = [
            config_assistant_system(),
            schema_analyzer_system(),
            config_generation_system(),
            validator_system(),
        ];
        assert!(prompts.iter().all(|p| !p.is_empty()));
    }

    /// Each prompt must mention its key instruction keyword.
    #[test]
    fn test_system_prompts_contain_guidance() {
        assert!(config_assistant_system().contains("clarifying questions"));
        assert!(schema_analyzer_system().contains("schema"));
        assert!(config_generation_system().contains("Nickel"));
        assert!(validator_system().contains("error"));
    }
}
|
||||
333
crates/typedialog-ai/src/llm/providers/anthropic.rs
Normal file
333
crates/typedialog-ai/src/llm/providers/anthropic.rs
Normal file
@ -0,0 +1,333 @@
|
||||
//! Anthropic Claude LLM provider
|
||||
|
||||
use super::super::error::{LlmError, Result};
|
||||
use super::super::messages::{Message, Role};
|
||||
use super::super::options::GenerationOptions;
|
||||
use super::super::{LlmProvider, StreamResponse};
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::StreamExt;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";

/// Anthropic API request
///
/// JSON body for `/v1/messages`; the system prompt is a dedicated
/// top-level field rather than a message entry.
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: usize,
    system: String,
    messages: Vec<AnthropicMessage>,
    temperature: f32,
    stream: bool,
}

/// Anthropic message format
// Only "user" and "assistant" roles are sent (system is separate).
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

/// Anthropic API response
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    // Responses may carry multiple content blocks.
    content: Vec<ContentBlock>,
}

/// Content block in response
#[derive(Debug, Deserialize)]
struct ContentBlock {
    // Defaulted so non-text block types deserialize as empty text.
    #[serde(default)]
    text: String,
}

/// Streaming event
// Tagged by the SSE event "type"; anything other than a text delta
// collapses into `Other` and is ignored.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
enum StreamEvent {
    ContentBlockDelta {
        delta: StreamDelta,
    },
    #[serde(other)]
    Other,
}

/// Stream delta
#[derive(Debug, Deserialize)]
struct StreamDelta {
    #[serde(default)]
    text: String,
}

/// Anthropic provider
pub struct AnthropicProvider {
    api_key: String,
    model: String,
    // Shared HTTP client; Arc allows cheap handle clones.
    client: Arc<Client>,
    // Sent as the required `anthropic-version` header.
    api_version: String,
}
|
||||
|
||||
impl AnthropicProvider {
    /// Create new Anthropic provider
    ///
    /// # Arguments
    /// * `api_key` - Anthropic API key
    /// * `model` - Model name (e.g., "claude-3-opus", "claude-3-sonnet")
    ///
    /// # Errors
    /// Returns error if api_key is empty
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
        let api_key = api_key.into();
        if api_key.is_empty() {
            return Err(LlmError::ConfigError("Anthropic API key is empty".into()));
        }

        Ok(AnthropicProvider {
            api_key,
            model: model.into(),
            client: Arc::new(Client::new()),
            // API version pinned here; bump deliberately when migrating.
            api_version: "2023-06-01".to_string(),
        })
    }

    /// Create from environment variable `ANTHROPIC_API_KEY`
    ///
    /// # Errors
    /// Returns `ConfigError` if the variable is unset.
    pub fn from_env(model: impl Into<String>) -> Result<Self> {
        let api_key = std::env::var("ANTHROPIC_API_KEY")
            .map_err(|_| LlmError::ConfigError("ANTHROPIC_API_KEY not set".into()))?;
        Self::new(api_key, model)
    }

    /// Translate generic messages into an Anthropic request body.
    ///
    /// The first system message becomes the top-level `system` field
    /// (with a generic fallback); remaining system messages are dropped.
    fn build_request(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> Result<AnthropicRequest> {
        // Extract system message if present
        let system = messages
            .iter()
            .find(|m| m.role == Role::System)
            .map(|m| m.content.clone())
            .unwrap_or_else(|| "You are a helpful assistant".to_string());

        // Filter out system messages from message list (Anthropic expects them separately)
        let anthropic_messages = messages
            .iter()
            .filter(|m| m.role != Role::System)
            .map(|msg| AnthropicMessage {
                role: match msg.role {
                    // Safe: System entries were filtered out above.
                    Role::System => unreachable!(),
                    Role::User => "user",
                    Role::Assistant => "assistant",
                }
                .to_string(),
                content: msg.content.clone(),
            })
            .collect();

        Ok(AnthropicRequest {
            model: self.model.clone(),
            max_tokens: options.max_tokens.unwrap_or(2048),
            system,
            messages: anthropic_messages,
            // Anthropic's temperature range is 0..=1, narrower than the
            // 0..=2 range GenerationOptions allows, hence the clamp.
            temperature: options.temperature.clamp(0.0, 1.0),
            stream: false,
        })
    }

    /// Same as `build_request` but with streaming enabled.
    fn build_stream_request(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> Result<AnthropicRequest> {
        let mut req = self.build_request(messages, options)?;
        req.stream = true;
        Ok(req)
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for AnthropicProvider {
|
||||
async fn generate(&self, messages: &[Message], options: &GenerationOptions) -> Result<String> {
|
||||
let req = self.build_request(messages, options)?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(ANTHROPIC_API_URL)
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", &self.api_version)
|
||||
.header("content-type", "application/json")
|
||||
.json(&req)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
if e.is_timeout() {
|
||||
LlmError::Timeout("Anthropic request timed out".into())
|
||||
} else if e.status().is_some_and(|s| s.as_u16() == 429) {
|
||||
LlmError::RateLimit("Anthropic rate limited".into())
|
||||
} else {
|
||||
LlmError::NetworkError(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(LlmError::ApiError(format!(
|
||||
"Anthropic API error ({}): {}",
|
||||
status, error_text
|
||||
)));
|
||||
}
|
||||
|
||||
let body: AnthropicResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::ApiError(format!("Failed to parse response: {}", e)))?;
|
||||
|
||||
body.content
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|block| block.text)
|
||||
.ok_or_else(|| LlmError::ApiError("No content in response".into()))
|
||||
}
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
options: &GenerationOptions,
|
||||
) -> Result<StreamResponse> {
|
||||
let req = self.build_stream_request(messages, options)?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(ANTHROPIC_API_URL)
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", &self.api_version)
|
||||
.header("content-type", "application/json")
|
||||
.json(&req)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
if e.is_timeout() {
|
||||
LlmError::Timeout("Anthropic request timed out".into())
|
||||
} else {
|
||||
LlmError::NetworkError(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
return Err(LlmError::ApiError(format!(
|
||||
"Anthropic API error: {}",
|
||||
status
|
||||
)));
|
||||
}
|
||||
|
||||
let stream = response
|
||||
.bytes_stream()
|
||||
.filter_map(|result| async move {
|
||||
match result {
|
||||
Ok(bytes) => {
|
||||
let text = String::from_utf8_lossy(&bytes);
|
||||
let text = text.trim();
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
for line in text.lines() {
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if let Ok(event) = serde_json::from_str::<StreamEvent>(data) {
|
||||
match event {
|
||||
StreamEvent::ContentBlockDelta { delta } => {
|
||||
if !delta.text.is_empty() {
|
||||
return Some(Ok(delta.text));
|
||||
}
|
||||
}
|
||||
StreamEvent::Other => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
Err(e) => Some(Err(LlmError::StreamError(e.to_string()))),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"Anthropic"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
let req = AnthropicRequest {
|
||||
model: self.model.clone(),
|
||||
max_tokens: 1,
|
||||
system: "test".to_string(),
|
||||
messages: vec![AnthropicMessage {
|
||||
role: "user".to_string(),
|
||||
content: "test".to_string(),
|
||||
}],
|
||||
temperature: 0.7,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
self.client
|
||||
.post(ANTHROPIC_API_URL)
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", &self.api_version)
|
||||
.json(&req)
|
||||
.send()
|
||||
.await
|
||||
.map(|r| r.status().is_success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_anthropic_new() {
        assert!(AnthropicProvider::new("sk-ant-test", "claude-3-opus").is_ok());
    }

    /// An empty API key must be rejected at construction time.
    #[test]
    fn test_anthropic_empty_key() {
        assert!(AnthropicProvider::new("", "claude-3-opus").is_err());
    }

    #[test]
    fn test_provider_name() {
        let provider = AnthropicProvider::new("sk-ant-test", "claude-3-opus").unwrap();
        assert_eq!(
            (provider.name(), provider.model()),
            ("Anthropic", "claude-3-opus")
        );
    }

    /// System messages must be lifted out of the message list.
    #[test]
    fn test_build_request() {
        let provider = AnthropicProvider::new("sk-ant-test", "claude-3-opus").unwrap();
        let messages = vec![Message::system("You are helpful"), Message::user("Hello")];
        let req = provider
            .build_request(&messages, &GenerationOptions::default())
            .unwrap();

        assert_eq!(req.model, "claude-3-opus");
        assert_eq!(req.messages.len(), 1); // System filtered out
        assert!(!req.stream);
    }
}
|
||||
22
crates/typedialog-ai/src/llm/providers/mod.rs
Normal file
22
crates/typedialog-ai/src/llm/providers/mod.rs
Normal file
@ -0,0 +1,22 @@
|
||||
//! LLM provider implementations
|
||||
|
||||
#[cfg(feature = "openai")]
|
||||
pub mod openai;
|
||||
|
||||
#[cfg(feature = "anthropic")]
|
||||
pub mod anthropic;
|
||||
|
||||
#[cfg(feature = "ollama")]
|
||||
pub mod ollama;
|
||||
|
||||
#[cfg(feature = "openai")]
|
||||
#[allow(unused_imports)]
|
||||
pub use openai::OpenAiProvider;
|
||||
|
||||
#[cfg(feature = "anthropic")]
|
||||
#[allow(unused_imports)]
|
||||
pub use anthropic::AnthropicProvider;
|
||||
|
||||
#[cfg(feature = "ollama")]
|
||||
#[allow(unused_imports)]
|
||||
pub use ollama::OllamaProvider;
|
||||
266
crates/typedialog-ai/src/llm/providers/ollama.rs
Normal file
266
crates/typedialog-ai/src/llm/providers/ollama.rs
Normal file
@ -0,0 +1,266 @@
|
||||
//! Ollama local LLM provider
|
||||
|
||||
use super::super::error::{LlmError, Result};
|
||||
use super::super::messages::{Message, Role};
|
||||
use super::super::options::GenerationOptions;
|
||||
use super::super::{LlmProvider, StreamResponse};
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::StreamExt;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Ollama API request
// NOTE(review): per the Ollama /api/chat docs, sampling parameters such
// as temperature belong inside a nested `options` object; a top-level
// `temperature` field may be ignored by the server — verify against the
// Ollama API before relying on it.
#[derive(Debug, Serialize)]
struct OllamaRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    temperature: f32,
    stream: bool,
}

/// Ollama message format
#[derive(Debug, Serialize, Deserialize)]
struct OllamaMessage {
    role: String,
    content: String,
}

/// Ollama API response
// Non-streaming: one message holding the full reply.
#[derive(Debug, Deserialize)]
struct OllamaResponse {
    message: OllamaMessage,
}

/// Ollama streaming response
// Streaming: one NDJSON object per chunk, each with a partial message.
#[derive(Debug, Deserialize)]
struct OllamaStreamResponse {
    message: OllamaMessage,
}

/// Ollama provider
pub struct OllamaProvider {
    model: String,
    // Shared HTTP client; Arc allows cheap handle clones.
    client: Arc<Client>,
    // Server root, e.g. "http://localhost:11434" (no trailing slash).
    base_url: String,
}
|
||||
|
||||
impl OllamaProvider {
    /// Create new Ollama provider
    ///
    /// # Arguments
    /// * `model` - Model name (e.g., "llama2", "mistral", "neural-chat")
    /// * `base_url` - Ollama server URL (default: http://localhost:11434)
    ///
    /// # Errors
    /// Returns error if model is empty
    pub fn new(model: impl Into<String>) -> Result<Self> {
        Self::with_url(model, "http://localhost:11434")
    }

    /// Create with custom Ollama server URL
    ///
    /// # Errors
    /// Returns `ConfigError` if the model name is empty.
    pub fn with_url(model: impl Into<String>, base_url: impl Into<String>) -> Result<Self> {
        let model = model.into();
        if model.is_empty() {
            return Err(LlmError::ConfigError("Ollama model name is empty".into()));
        }

        Ok(OllamaProvider {
            model,
            client: Arc::new(Client::new()),
            base_url: base_url.into(),
        })
    }

    /// Create from environment variable `OLLAMA_MODEL`
    ///
    /// # Errors
    /// Returns `ConfigError` if the variable is unset.
    pub fn from_env() -> Result<Self> {
        let model = std::env::var("OLLAMA_MODEL")
            .map_err(|_| LlmError::ConfigError("OLLAMA_MODEL not set".into()))?;
        Self::new(model)
    }

    /// Translate generic messages into an Ollama chat request.
    // Unlike Anthropic, system messages stay inline in the message list.
    fn build_request(&self, messages: &[Message], options: &GenerationOptions) -> OllamaRequest {
        let ollama_messages = messages
            .iter()
            .map(|msg| OllamaMessage {
                role: match msg.role {
                    Role::System => "system",
                    Role::User => "user",
                    Role::Assistant => "assistant",
                }
                .to_string(),
                content: msg.content.clone(),
            })
            .collect();

        OllamaRequest {
            model: self.model.clone(),
            messages: ollama_messages,
            // NOTE(review): see OllamaRequest — temperature may need to
            // live under an `options` object to take effect.
            temperature: options.temperature.clamp(0.0, 2.0),
            stream: false,
        }
    }

    /// Same as `build_request` but with streaming enabled.
    fn build_stream_request(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> OllamaRequest {
        let mut req = self.build_request(messages, options);
        req.stream = true;
        req
    }
}
|
||||
|
||||
#[async_trait]
impl LlmProvider for OllamaProvider {
    /// Generate a complete (non-streaming) response via `/api/chat`.
    ///
    /// # Errors
    /// `Timeout` on client timeout, `ApiError` on non-2xx status or
    /// unparsable body, `NetworkError` on transport failure.
    async fn generate(&self, messages: &[Message], options: &GenerationOptions) -> Result<String> {
        let req = self.build_request(messages, options);
        let url = format!("{}/api/chat", self.base_url);

        let response = self
            .client
            .post(&url)
            .json(&req)
            .send()
            .await
            .map_err(|e| {
                if e.is_timeout() {
                    LlmError::Timeout("Ollama request timed out".into())
                } else {
                    LlmError::NetworkError(e)
                }
            })?;

        let status = response.status();
        if !status.is_success() {
            // Include the body in the error for easier local debugging.
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(LlmError::ApiError(format!(
                "Ollama error ({}): {}",
                status, error_text
            )));
        }

        let body: OllamaResponse = response
            .json()
            .await
            .map_err(|e| LlmError::ApiError(format!("Failed to parse response: {}", e)))?;

        Ok(body.message.content)
    }

    /// Stream a response as NDJSON chunks, yielding message deltas.
    ///
    /// # Errors
    /// Fails if the request cannot be started or returns non-2xx;
    /// per-chunk transport errors surface as `StreamError` items.
    async fn stream(
        &self,
        messages: &[Message],
        options: &GenerationOptions,
    ) -> Result<StreamResponse> {
        let req = self.build_stream_request(messages, options);
        let url = format!("{}/api/chat", self.base_url);

        let response = self
            .client
            .post(&url)
            .json(&req)
            .send()
            .await
            .map_err(|e| {
                if e.is_timeout() {
                    LlmError::Timeout("Ollama request timed out".into())
                } else {
                    LlmError::NetworkError(e)
                }
            })?;

        let status = response.status();
        if !status.is_success() {
            return Err(LlmError::ApiError(format!("Ollama error: {}", status)));
        }

        // NOTE(review): each network chunk is parsed as ONE JSON object;
        // a chunk holding several NDJSON lines, or a JSON object split
        // across chunk boundaries, is silently dropped. Fixing it needs
        // a stateful line buffer — not changed here.
        let stream = response
            .bytes_stream()
            .filter_map(|result| async move {
                match result {
                    Ok(bytes) => {
                        let text = String::from_utf8_lossy(&bytes);
                        let text = text.trim();
                        if text.is_empty() {
                            return None;
                        }

                        if let Ok(chunk) = serde_json::from_str::<OllamaStreamResponse>(text) {
                            if !chunk.message.content.is_empty() {
                                return Some(Ok(chunk.message.content));
                            }
                        }
                        None
                    }
                    Err(e) => Some(Err(LlmError::StreamError(e.to_string()))),
                }
            })
            .boxed();

        Ok(Box::pin(stream))
    }

    fn name(&self) -> &str {
        "Ollama"
    }

    fn model(&self) -> &str {
        &self.model
    }

    /// Probe availability with a cheap GET of the local model list.
    async fn is_available(&self) -> bool {
        self.client
            .get(format!("{}/api/tags", self.base_url))
            .send()
            .await
            .map(|r| r.status().is_success())
            .unwrap_or(false)
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ollama_new() {
        assert!(OllamaProvider::new("llama2").is_ok());
    }

    /// An empty model name must be rejected at construction time.
    #[test]
    fn test_ollama_empty_model() {
        assert!(OllamaProvider::new("").is_err());
    }

    #[test]
    fn test_ollama_with_url() {
        assert!(OllamaProvider::with_url("llama2", "http://127.0.0.1:11434").is_ok());
    }

    #[test]
    fn test_provider_name() {
        let provider = OllamaProvider::new("llama2").unwrap();
        assert_eq!((provider.name(), provider.model()), ("Ollama", "llama2"));
    }

    /// Request building keeps all messages and defaults to non-streaming.
    #[test]
    fn test_build_request() {
        let provider = OllamaProvider::new("llama2").unwrap();
        let messages = vec![Message::user("Hello")];
        let req = provider.build_request(&messages, &GenerationOptions::default());

        assert_eq!(req.model, "llama2");
        assert_eq!(req.messages.len(), 1);
        assert!(!req.stream);
    }
}
|
||||
325
crates/typedialog-ai/src/llm/providers/openai.rs
Normal file
325
crates/typedialog-ai/src/llm/providers/openai.rs
Normal file
@ -0,0 +1,325 @@
|
||||
//! OpenAI LLM provider (GPT-3.5, GPT-4)
|
||||
|
||||
use super::super::error::{LlmError, Result};
|
||||
use super::super::messages::{Message, Role};
|
||||
use super::super::options::GenerationOptions;
|
||||
use super::super::{LlmProvider, StreamResponse};
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::StreamExt;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
|
||||
|
||||
/// OpenAI API request
#[derive(Debug, Serialize)]
struct OpenAiRequest {
    // Model identifier, e.g. "gpt-4" or "gpt-3.5-turbo".
    model: String,
    // Conversation history in OpenAI's role/content format.
    messages: Vec<OpenAiMessage>,
    // Always serialized; clamped to [0, 2] by build_request.
    temperature: f32,
    // Optional knobs are omitted from the JSON payload when unset,
    // letting the API apply its own defaults.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    // true selects the SSE streaming response mode.
    stream: bool,
}

/// OpenAI message format
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiMessage {
    // "system" | "user" | "assistant"
    role: String,
    content: String,
}
|
||||
|
||||
/// OpenAI API response
#[derive(Debug, Deserialize)]
struct OpenAiResponse {
    // Completion alternatives; only the first is consumed by generate().
    choices: Vec<OpenAiChoice>,
}

/// OpenAI choice (completion option)
#[derive(Debug, Deserialize)]
struct OpenAiChoice {
    // The assistant message produced for this choice.
    message: OpenAiMessage,
    // Reported by the API (e.g. "stop", "length") but currently unused.
    #[allow(dead_code)]
    finish_reason: String,
}

/// Streaming response chunk
#[derive(Debug, Deserialize)]
struct OpenAiStreamChunk {
    choices: Vec<StreamChoice>,
}

/// Streaming choice
#[derive(Debug, Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
}

/// Stream delta content
#[derive(Debug, Deserialize)]
struct StreamDelta {
    // Defaults to "" when the delta carries no text (e.g. role-only chunks).
    #[serde(default)]
    content: String,
}
|
||||
|
||||
/// OpenAI provider
///
/// Holds the API credentials and a shared HTTP client.
/// Note: no `Debug` derive — deriving one would expose `api_key` via
/// debug formatting in logs.
pub struct OpenAiProvider {
    // Secret bearer token; sent as `Authorization: Bearer <key>`.
    api_key: String,
    // Model name forwarded in every request.
    model: String,
    // Shared reqwest client (connection pooling across requests).
    client: Arc<Client>,
}
|
||||
|
||||
impl OpenAiProvider {
|
||||
/// Create new OpenAI provider
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `api_key` - OpenAI API key
|
||||
/// * `model` - Model name (e.g., "gpt-4", "gpt-3.5-turbo")
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns error if api_key is empty
|
||||
pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
|
||||
let api_key = api_key.into();
|
||||
if api_key.is_empty() {
|
||||
return Err(LlmError::ConfigError("OpenAI API key is empty".into()));
|
||||
}
|
||||
|
||||
Ok(OpenAiProvider {
|
||||
api_key,
|
||||
model: model.into(),
|
||||
client: Arc::new(Client::new()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create from environment variable `OPENAI_API_KEY`
|
||||
pub fn from_env(model: impl Into<String>) -> Result<Self> {
|
||||
let api_key = std::env::var("OPENAI_API_KEY")
|
||||
.map_err(|_| LlmError::ConfigError("OPENAI_API_KEY not set".into()))?;
|
||||
Self::new(api_key, model)
|
||||
}
|
||||
|
||||
fn build_request(&self, messages: &[Message], options: &GenerationOptions) -> OpenAiRequest {
|
||||
let openai_messages = messages
|
||||
.iter()
|
||||
.map(|msg| OpenAiMessage {
|
||||
role: match msg.role {
|
||||
Role::System => "system",
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
}
|
||||
.to_string(),
|
||||
content: msg.content.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
OpenAiRequest {
|
||||
model: self.model.clone(),
|
||||
messages: openai_messages,
|
||||
temperature: options.temperature.clamp(0.0, 2.0),
|
||||
max_tokens: options.max_tokens,
|
||||
top_p: options.top_p,
|
||||
presence_penalty: options.presence_penalty,
|
||||
frequency_penalty: options.frequency_penalty,
|
||||
stream: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_stream_request(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
options: &GenerationOptions,
|
||||
) -> OpenAiRequest {
|
||||
let mut req = self.build_request(messages, options);
|
||||
req.stream = true;
|
||||
req
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for OpenAiProvider {
|
||||
async fn generate(&self, messages: &[Message], options: &GenerationOptions) -> Result<String> {
|
||||
let req = self.build_request(messages, options);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(OPENAI_API_URL)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&req)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
if e.is_timeout() {
|
||||
LlmError::Timeout("OpenAI request timed out".into())
|
||||
} else if e.status().is_some_and(|s| s.as_u16() == 429) {
|
||||
LlmError::RateLimit("OpenAI rate limited".into())
|
||||
} else {
|
||||
LlmError::NetworkError(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(LlmError::ApiError(format!(
|
||||
"OpenAI API error ({}): {}",
|
||||
status, error_text
|
||||
)));
|
||||
}
|
||||
|
||||
let body: OpenAiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::ApiError(format!("Failed to parse response: {}", e)))?;
|
||||
|
||||
body.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|choice| choice.message.content)
|
||||
.ok_or_else(|| LlmError::ApiError("No choices in response".into()))
|
||||
}
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
options: &GenerationOptions,
|
||||
) -> Result<StreamResponse> {
|
||||
let req = self.build_stream_request(messages, options);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(OPENAI_API_URL)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&req)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
if e.is_timeout() {
|
||||
LlmError::Timeout("OpenAI request timed out".into())
|
||||
} else {
|
||||
LlmError::NetworkError(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
return Err(LlmError::ApiError(format!("OpenAI API error: {}", status)));
|
||||
}
|
||||
|
||||
let stream = response
|
||||
.bytes_stream()
|
||||
.filter_map(|result| async move {
|
||||
match result {
|
||||
Ok(bytes) => {
|
||||
let text = String::from_utf8_lossy(&bytes);
|
||||
let text = text.trim();
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
for line in text.lines() {
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if data == "[DONE]" {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Ok(chunk) = serde_json::from_str::<OpenAiStreamChunk>(data) {
|
||||
if let Some(choice) = chunk.choices.first() {
|
||||
if !choice.delta.content.is_empty() {
|
||||
return Some(Ok(choice.delta.content.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
Err(e) => Some(Err(LlmError::StreamError(e.to_string()))),
|
||||
}
|
||||
})
|
||||
.boxed();
|
||||
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"OpenAI"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
// Quick health check - attempt to verify API key
|
||||
let req = serde_json::json!({
|
||||
"model": &self.model,
|
||||
"messages": [{"role": "user", "content": "test"}],
|
||||
"max_tokens": 1
|
||||
});
|
||||
|
||||
self.client
|
||||
.post(OPENAI_API_URL)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&req)
|
||||
.send()
|
||||
.await
|
||||
.map(|r| r.status().is_success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor accepts a non-empty key and model.
    #[test]
    fn test_openai_new() {
        let provider = OpenAiProvider::new("sk-test", "gpt-4");
        assert!(provider.is_ok());
    }

    // An empty API key is rejected at construction time.
    #[test]
    fn test_openai_empty_key() {
        let provider = OpenAiProvider::new("", "gpt-4");
        assert!(provider.is_err());
    }

    // name()/model() accessors report the provider identity.
    #[test]
    fn test_provider_name() {
        let provider = OpenAiProvider::new("sk-test", "gpt-4").unwrap();
        assert_eq!(provider.name(), "OpenAI");
        assert_eq!(provider.model(), "gpt-4");
    }

    // Non-streaming requests carry the model, the messages, and stream=false.
    #[test]
    fn test_build_request() {
        let provider = OpenAiProvider::new("sk-test", "gpt-4").unwrap();
        let messages = vec![Message::user("Hello")];
        let options = GenerationOptions::default();

        let req = provider.build_request(&messages, &options);
        assert_eq!(req.model, "gpt-4");
        assert_eq!(req.messages.len(), 1);
        assert!(!req.stream);
    }

    // The streaming variant flips the stream flag on.
    #[test]
    fn test_build_stream_request() {
        let provider = OpenAiProvider::new("sk-test", "gpt-4").unwrap();
        let messages = vec![Message::user("Hello")];
        let options = GenerationOptions::default();

        let req = provider.build_stream_request(&messages, &options);
        assert!(req.stream);
    }
}
|
||||
241
crates/typedialog-ai/src/llm/rag_integration.rs
Normal file
241
crates/typedialog-ai/src/llm/rag_integration.rs
Normal file
@ -0,0 +1,241 @@
|
||||
//! Integration with RAG system for LLM context formatting
|
||||
//!
|
||||
//! This module provides helpers to format RAG retrieval results as useful
|
||||
//! context for LLM prompts in the configuration assistant.
|
||||
|
||||
/// RAG retrieval result (mirrors typedialog_core::ai::rag::RetrievalResult)
///
/// `Debug` is derived so results can be inspected in logs and test output.
#[derive(Debug, Clone)]
pub struct RetrievalResult {
    /// Document ID/source identifier
    pub doc_id: String,
    /// Document content
    pub content: String,
    /// Combined relevance score (0-1)
    pub combined_score: f32,
}
|
||||
|
||||
/// Format RAG results as contextual examples for LLM
|
||||
///
|
||||
/// Converts a list of RAG retrieval results into a formatted context
|
||||
/// that can be included in LLM prompts to provide relevant examples.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use typedialog_ai::llm::rag_integration::{format_rag_context, RetrievalResult};
|
||||
///
|
||||
/// let results = vec![
|
||||
/// RetrievalResult {
|
||||
/// doc_id: "example1".to_string(),
|
||||
/// content: "port = 8080".to_string(),
|
||||
/// combined_score: 0.92,
|
||||
/// },
|
||||
/// ];
|
||||
///
|
||||
/// let context = format_rag_context(&results);
|
||||
/// println!("Context:\n{}", context);
|
||||
/// ```
|
||||
pub fn format_rag_context(results: &[RetrievalResult]) -> String {
|
||||
if results.is_empty() {
|
||||
return "No relevant examples found.".to_string();
|
||||
}
|
||||
|
||||
let mut context = String::from("Here are relevant examples from similar configurations:\n\n");
|
||||
|
||||
for (idx, result) in results.iter().enumerate() {
|
||||
context.push_str(&format!(
|
||||
"Example {} (from {}): [Relevance: {:.1}%]\n{}\n\n",
|
||||
idx + 1,
|
||||
result.doc_id,
|
||||
result.combined_score * 100.0,
|
||||
result.content
|
||||
));
|
||||
}
|
||||
|
||||
context
|
||||
}
|
||||
|
||||
/// Extract field value suggestions from RAG results
|
||||
///
|
||||
/// Analyzes RAG results to extract common values for a specific field name.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use typedialog_ai::llm::rag_integration::{extract_field_values, RetrievalResult};
|
||||
///
|
||||
/// let results = vec![
|
||||
/// RetrievalResult {
|
||||
/// doc_id: "config1".to_string(),
|
||||
/// content: "port = 8080\nhost = localhost".to_string(),
|
||||
/// combined_score: 0.90,
|
||||
/// },
|
||||
/// ];
|
||||
///
|
||||
/// let values = extract_field_values(&results, "port");
|
||||
/// // Would extract "8080" from the results
|
||||
/// ```
|
||||
pub fn extract_field_values(results: &[RetrievalResult], field_name: &str) -> Vec<String> {
|
||||
let mut values = Vec::new();
|
||||
|
||||
for result in results {
|
||||
// Simple pattern matching for field values
|
||||
// Looks for patterns like "field_name = value" or "field_name: value"
|
||||
let lines = result.content.lines();
|
||||
|
||||
for line in lines {
|
||||
let trimmed = line.trim();
|
||||
|
||||
// Check for TOML/Nickel style: field = value
|
||||
if let Some(idx) = trimmed.find('=') {
|
||||
let key_part = trimmed[..idx].trim();
|
||||
if key_part == field_name {
|
||||
let value_part = trimmed[idx + 1..].trim();
|
||||
let value = clean_value(value_part);
|
||||
if !value.is_empty() && !values.contains(&value) {
|
||||
values.push(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for JSON/YAML style: field: value or "field": value
|
||||
if let Some(idx) = trimmed.find(':') {
|
||||
let key_part = trimmed[..idx].trim().trim_matches('"');
|
||||
if key_part == field_name {
|
||||
let value_part = trimmed[idx + 1..].trim();
|
||||
let value = clean_value(value_part);
|
||||
if !value.is_empty() && !values.contains(&value) {
|
||||
values.push(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
values
|
||||
}
|
||||
|
||||
/// Clean extracted field values
///
/// Normalizes a raw right-hand side: trims whitespace, drops a trailing
/// comma, strips surrounding single/double quotes, and cuts off any
/// trailing `#` comment.
fn clean_value(value: &str) -> String {
    let mut cleaned = value.trim();

    // Remove trailing comma (TOML syntax)
    if let Some(without_comma) = cleaned.strip_suffix(',') {
        cleaned = without_comma;
    }

    // Remove surrounding quotes
    cleaned = cleaned.trim_matches(|c| c == '"' || c == '\'');

    // Cut off trailing `#` comments
    match cleaned.find('#') {
        Some(idx) => cleaned[..idx].trim().to_string(),
        None => cleaned.to_string(),
    }
}
|
||||
|
||||
/// Format RAG results with field-specific recommendations
|
||||
///
|
||||
/// Creates a structured recommendation list for a specific field
|
||||
/// based on RAG retrieval results.
|
||||
pub fn format_field_recommendations(
|
||||
field_name: &str,
|
||||
field_description: &str,
|
||||
results: &[RetrievalResult],
|
||||
) -> String {
|
||||
let values = extract_field_values(results, field_name);
|
||||
|
||||
if values.is_empty() {
|
||||
return format!(
|
||||
"No direct examples found for '{}', but {} is an important configuration setting.",
|
||||
field_name, field_description
|
||||
);
|
||||
}
|
||||
|
||||
let mut rec = format!(
|
||||
"For the '{}' field ({}):\nCommon values from similar configurations:\n",
|
||||
field_name, field_description
|
||||
);
|
||||
|
||||
for (idx, value) in values.iter().enumerate() {
|
||||
rec.push_str(&format!(" {}. {}\n", idx + 1, value));
|
||||
}
|
||||
|
||||
rec
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Empty input yields the fixed fallback sentence.
    #[test]
    fn test_format_rag_context_empty() {
        let results = vec![];
        let context = format_rag_context(&results);
        assert_eq!(context, "No relevant examples found.");
    }

    // A single result is rendered with index, doc id, content, and score.
    #[test]
    fn test_format_rag_context() {
        let results = vec![RetrievalResult {
            doc_id: "example1".to_string(),
            content: "port = 8080".to_string(),
            combined_score: 0.95,
        }];

        let context = format_rag_context(&results);
        assert!(context.contains("Example 1"));
        assert!(context.contains("example1"));
        assert!(context.contains("port = 8080"));
        assert!(context.contains("95.0%"));
    }

    // TOML-style `key = value` lines are extracted per field.
    #[test]
    fn test_extract_field_values_toml() {
        let results = vec![RetrievalResult {
            doc_id: "config.toml".to_string(),
            content: "port = 8080\nhost = localhost".to_string(),
            combined_score: 0.90,
        }];

        let values = extract_field_values(&results, "port");
        assert!(values.contains(&"8080".to_string()));

        let hosts = extract_field_values(&results, "host");
        assert!(hosts.contains(&"localhost".to_string()));
    }

    // JSON-style quoted keys (`"key": value`) are also recognized.
    // (Interior whitespace of the raw string is irrelevant: the code
    // under test trims each line before matching.)
    #[test]
    fn test_extract_field_values_json() {
        let results = vec![RetrievalResult {
            doc_id: "config.json".to_string(),
            content: r#" "port": 8080,
            "host": "localhost""#
                .to_string(),
            combined_score: 0.90,
        }];

        let values = extract_field_values(&results, "host");
        assert!(values.contains(&"localhost".to_string()));
    }

    // clean_value strips commas, quotes, and trailing `#` comments.
    #[test]
    fn test_clean_value() {
        assert_eq!(clean_value("8080"), "8080");
        assert_eq!(clean_value("\"localhost\""), "localhost");
        assert_eq!(clean_value("true,"), "true");
        assert_eq!(clean_value("value # comment"), "value");
    }

    // Recommendations list extracted values under the standard header.
    #[test]
    fn test_format_field_recommendations() {
        let results = vec![RetrievalResult {
            doc_id: "example".to_string(),
            content: "port = 8080\nport = 3000".to_string(),
            combined_score: 0.90,
        }];

        let rec = format_field_recommendations("port", "Server port number", &results);
        assert!(rec.contains("Common values"));
        assert!(rec.contains("8080") || rec.contains("3000"));
    }
}
|
||||
244
crates/typedialog-ai/src/main.rs
Normal file
244
crates/typedialog-ai/src/main.rs
Normal file
@ -0,0 +1,244 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
//! TypeDialog AI Configuration Assistant Microservice
|
||||
//!
|
||||
//! Provides intelligent configuration assistance through:
|
||||
//! - Conversational dialog with LLM
|
||||
//! - RAG-powered example suggestions
|
||||
//! - Nickel schema understanding
|
||||
//! - Configuration generation and validation
|
||||
//!
|
||||
//! # Running the service
|
||||
//!
|
||||
//! ```bash
|
||||
//! # Start the HTTP server
|
||||
//! typedialog-ai serve --port 3000
|
||||
//!
|
||||
//! # Or run in CLI interactive mode
|
||||
//! typedialog-ai cli --schema schema.ncl
|
||||
//! ```
|
||||
|
||||
mod api;
|
||||
mod assistant;
|
||||
mod backend;
|
||||
mod cli;
|
||||
mod llm;
|
||||
mod storage;
|
||||
mod web_ui;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, Subcommand};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
// Top-level CLI definition. NOTE: `///` doc comments on clap-derive
// fields become user-visible help text, so reviewer notes here use `//`.
#[derive(Parser)]
#[command(name = "typedialog-ai")]
#[command(about = "AI-powered configuration assistant backend and microservice for TypeDialog", long_about = None)]
struct Cli {
    // Subcommand to run; `None` falls back to `serve` with defaults (see main).
    #[command(subcommand)]
    command: Option<Commands>,

    /// AI backend configuration file (TOML)
    ///
    /// If provided, uses this file exclusively.
    /// If not provided, searches: ~/.config/typedialog/ai/{TYPEDIALOG_ENV}.toml → ~/.config/typedialog/ai/config.toml → defaults
    #[arg(global = true, short = 'c', long, value_name = "FILE")]
    config: Option<PathBuf>,

    /// Database path
    #[arg(global = true, long, default_value = ".typedialog/ai.db")]
    db_path: String,

    /// SurrealDB namespace
    #[arg(global = true, long, default_value = "default")]
    namespace: String,

    /// SurrealDB database
    #[arg(global = true, long, default_value = "typedialog")]
    database: String,
}
|
||||
|
||||
// Subcommands. As with `Cli`, `//` comments are used for notes that must
// not leak into the generated --help output.
#[derive(Subcommand)]
enum Commands {
    /// Start the HTTP server
    #[command(about = "Start the HTTP API server")]
    Serve {
        /// Server port
        #[arg(short, long, default_value = "3000")]
        port: u16,

        /// Server host
        #[arg(short = 'H', long, default_value = "127.0.0.1")]
        host: String,
    },

    /// Run in CLI interactive mode
    #[command(about = "Interactive CLI mode")]
    Cli {
        /// Nickel schema path
        #[arg(short, long)]
        schema: String,

        /// Conversation ID (resume existing)
        #[arg(long)]
        conversation_id: Option<String>,
    },
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Initialize tracing
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::from_default_env()
|
||||
.add_directive("typedialog_ai=debug".parse()?),
|
||||
)
|
||||
.init();
|
||||
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Load configuration
|
||||
let config_path = cli.config.as_deref();
|
||||
let config = typedialog_ai::config::TypeDialogAiConfig::load_with_cli(config_path)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load configuration: {}", e))?;
|
||||
|
||||
if let Some(path) = config_path {
|
||||
println!("📋 Using config: {}", path.display());
|
||||
}
|
||||
|
||||
match cli.command {
|
||||
Some(Commands::Serve { port, host }) => {
|
||||
serve_command(
|
||||
&host,
|
||||
port,
|
||||
&cli.db_path,
|
||||
&cli.namespace,
|
||||
&cli.database,
|
||||
config,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Some(Commands::Cli {
|
||||
schema,
|
||||
conversation_id,
|
||||
}) => {
|
||||
cli_command(
|
||||
&schema,
|
||||
conversation_id,
|
||||
&cli.db_path,
|
||||
&cli.namespace,
|
||||
&cli.database,
|
||||
config,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
None => {
|
||||
// Default to serve
|
||||
serve_command(
|
||||
"127.0.0.1",
|
||||
3000,
|
||||
&cli.db_path,
|
||||
&cli.namespace,
|
||||
&cli.database,
|
||||
config,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start the HTTP API server.
///
/// Initializes the SurrealDB-backed storage, builds the axum router with a
/// shared `AppState`, binds `host:port`, and serves until the process exits.
///
/// # Errors
/// Returns an error if the database cannot be initialized, the address
/// fails to parse or bind, or the server loop terminates with a failure.
async fn serve_command(
    host: &str,
    port: u16,
    db_path: &str,
    namespace: &str,
    database: &str,
    config: typedialog_ai::config::TypeDialogAiConfig,
) -> Result<()> {
    println!("🚀 Starting TypeDialog AI Service on {}:{}", host, port);
    println!(
        "🤖 LLM Provider: {} ({})",
        config.llm.provider, config.llm.model
    );
    println!(
        "🔍 RAG System: {}",
        if config.rag.enabled {
            "enabled"
        } else {
            "disabled"
        }
    );

    // Initialize database
    let client = Arc::new(storage::SurrealDbClient::new(db_path, namespace, database).await?);
    client.initialize().await?;

    println!("✅ Database initialized");
    println!("📡 API listening on http://{}:{}", host, port);
    println!("🌐 Web UI at http://{}:{}/ui", host, port);

    // Create application state shared across handlers.
    // NOTE(review): `assistants` is a std::sync::Mutex — confirm no async
    // handler holds the guard across an .await point.
    let app_state = api::AppState {
        db: client,
        assistants: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
        start_time: std::time::Instant::now(),
    };

    // Create router with all routes
    let app = api::create_router(app_state);

    // Parse host and create socket address
    let addr: std::net::SocketAddr = format!("{}:{}", host, port).parse()?;

    // Create listener
    let listener = tokio::net::TcpListener::bind(&addr).await?;

    tracing::info!("Server listening on {}", addr);

    // Run the server (blocks until shutdown)
    axum::serve(listener, app).await?;

    Ok(())
}
|
||||
|
||||
/// Run the interactive CLI configuration assistant.
///
/// Validates that the Nickel schema file exists, initializes storage, then
/// drives an interactive conversation loop (optionally resuming an
/// existing conversation by id).
///
/// # Errors
/// Fails if the schema file is missing, the database cannot be
/// initialized, or the interactive session aborts.
async fn cli_command(
    schema: &str,
    conversation_id: Option<String>,
    db_path: &str,
    namespace: &str,
    database: &str,
    config: typedialog_ai::config::TypeDialogAiConfig,
) -> Result<()> {
    // Verify schema file exists before touching the database
    let schema_path = Path::new(schema);
    if !schema_path.exists() {
        anyhow::bail!("Schema file not found: {}", schema);
    }

    println!(
        "🤖 LLM Provider: {} ({})",
        config.llm.provider, config.llm.model
    );
    println!(
        "🔍 RAG System: {}",
        if config.rag.enabled {
            "enabled"
        } else {
            "disabled"
        }
    );

    // Initialize database
    let client = Arc::new(storage::SurrealDbClient::new(db_path, namespace, database).await?);
    client.initialize().await?;

    // Create interactive session (resumes `conversation_id` when given)
    let mut session = cli::InteractiveSession::new(schema_path, client, conversation_id).await?;

    // Run interactive conversation loop until the user finishes or quits
    session.run().await?;

    Ok(())
}
|
||||
388
crates/typedialog-ai/src/storage/client.rs
Normal file
388
crates/typedialog-ai/src/storage/client.rs
Normal file
@ -0,0 +1,388 @@
|
||||
//! SurrealDB client and connection management
|
||||
//!
|
||||
//! Abstracts SurrealDB operations for the AI configuration assistant.
|
||||
//! Supports both local and remote connections via HTTP.
|
||||
|
||||
use crate::storage::models::*;
|
||||
use anyhow::Result;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// SurrealDB client for AI configuration assistant
///
/// Manages connections to SurrealDB (local or remote) with support for:
/// - Graph relationships between conversations, schemas, fields
/// - Full-text search on messages
/// - Reusable RAG results
/// - Schema validation with SurrealQL contracts
///
/// All fields are plain strings, so `Clone` is cheap and `Debug` output is
/// safe (no credentials are stored here).
#[derive(Debug, Clone)]
pub struct SurrealDbClient {
    /// Connection endpoint (e.g., "http://localhost:8000", "memory://", "file://./db.db")
    endpoint: String,
    /// SurrealDB namespace
    namespace: String,
    /// SurrealDB database name
    database: String,
}
|
||||
|
||||
impl SurrealDbClient {
|
||||
/// Create new SurrealDB client
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `endpoint` - Connection endpoint
|
||||
/// - `"memory://"` - In-memory (development, data lost on restart)
|
||||
/// - `"file://./typedialog.db"` - Local RocksDB file (development)
|
||||
/// - `"http://localhost:8000"` - Local HTTP server
|
||||
/// - `"https://surreal.example.com"` - Remote HTTP endpoint (production)
|
||||
/// * `namespace` - SurrealDB namespace (e.g., "default")
|
||||
/// * `database` - SurrealDB database name (e.g., "typedialog")
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns error if endpoint format is invalid
|
||||
pub async fn new(endpoint: &str, namespace: &str, database: &str) -> Result<Self> {
|
||||
// Validate endpoint
|
||||
let valid = endpoint == "memory://"
|
||||
|| endpoint.starts_with("file://")
|
||||
|| endpoint.starts_with("http");
|
||||
if !valid {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Invalid endpoint. Use 'memory://', 'file://<path>', or 'http(s)://<host>:<port>'"
|
||||
));
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
endpoint = %endpoint,
|
||||
namespace = %namespace,
|
||||
database = %database,
|
||||
"Creating SurrealDB client"
|
||||
);
|
||||
|
||||
Ok(SurrealDbClient {
|
||||
endpoint: endpoint.to_string(),
|
||||
namespace: namespace.to_string(),
|
||||
database: database.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get connection endpoint
/// (e.g., "memory://", "file://./db.db", "http://localhost:8000")
pub fn endpoint(&self) -> &str {
    &self.endpoint
}

/// Get namespace
pub fn namespace(&self) -> &str {
    &self.namespace
}

/// Get database name
pub fn database(&self) -> &str {
    &self.database
}
|
||||
|
||||
/// Initialize database schema
///
/// Creates necessary tables, indexes, and relationships:
/// - `conversations` - Main conversation records with status tracking
/// - `messages` - Individual messages with optional RAG result links
/// - `rag_results` - Reusable RAG retrieval results
/// - `schemas` - Nickel schema definitions
/// - `fields` - Extracted schema field metadata
///
/// All tables include proper indexing:
/// - Foreign key relationships for graph queries
/// - Full-text search indexes on messages
/// - Status and timestamp indexes for efficient filtering
///
/// NOTE(review): currently a stub — it only logs; no SurrealQL is
/// executed yet (see TODO below).
pub async fn initialize(&self) -> Result<()> {
    tracing::info!(
        endpoint = %self.endpoint,
        "Initializing SurrealDB schema (ns='{}', db='{}')",
        self.namespace,
        self.database
    );

    // TODO: Execute SurrealQL schema initialization via HTTP client
    // This would create:
    // - DEFINE TABLE conversations SCHEMAFULL;
    // - DEFINE TABLE messages SCHEMAFULL;
    // - DEFINE TABLE rag_results SCHEMAFULL;
    // - DEFINE TABLE schemas SCHEMAFULL;
    // - DEFINE TABLE fields SCHEMAFULL;
    // - DEFINE INDEX indexes for efficient queries

    Ok(())
}
|
||||
|
||||
// ============================================================================
// CONVERSATION OPERATIONS
// ============================================================================

/// Create new conversation
///
/// Initializes a new conversation for configuring a Nickel schema and
/// returns the generated conversation UUID.
///
/// NOTE(review): persistence is not wired yet — the id is generated and
/// logged but nothing is written to the database (see TODO).
pub async fn create_conversation(&self, schema_id: &str) -> Result<String> {
    let id = Uuid::new_v4().to_string();
    tracing::debug!(conversation_id = %id, schema_id = %schema_id, "Creating conversation");

    // TODO: INSERT INTO conversations (id, schema_id, created_at, updated_at, status, collected_fields)
    // VALUES (:id, :schema_id, now(), now(), 'active', {})

    Ok(id)
}

/// Get conversation by ID
/// (stub — always returns `None` until the query is implemented)
pub async fn get_conversation(&self, _id: &str) -> Result<Option<Conversation>> {
    // TODO: SELECT * FROM conversations WHERE id = :id
    Ok(None)
}

/// List conversations with filters
/// (stub — always returns an empty list)
pub async fn list_conversations(
    &self,
    _filter: &ConversationFilter,
) -> Result<Vec<Conversation>> {
    // TODO: Build query with filters and execute
    Ok(vec![])
}

/// Update conversation status
/// (stub — no-op until the query is implemented)
pub async fn update_conversation_status(
    &self,
    _id: &str,
    _status: ConversationStatus,
) -> Result<()> {
    // TODO: UPDATE conversations SET status = :status, updated_at = now() WHERE id = :id
    Ok(())
}

/// Update collected fields
/// (stub — no-op until the query is implemented)
pub async fn update_collected_fields(
    &self,
    _id: &str,
    _fields: serde_json::Value,
) -> Result<()> {
    // TODO: UPDATE conversations SET collected_fields = :fields, updated_at = now() WHERE id = :id
    Ok(())
}
|
||||
|
||||
// ============================================================================
|
||||
// MESSAGE OPERATIONS
|
||||
// ============================================================================
|
||||
|
||||
/// Create new message in conversation
|
||||
pub async fn create_message(
|
||||
&self,
|
||||
conversation_id: &str,
|
||||
role: MessageRole,
|
||||
_content: &str,
|
||||
) -> Result<String> {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
tracing::debug!(message_id = %id, conversation_id = %conversation_id, role = ?role, "Creating message");
|
||||
|
||||
// TODO: INSERT INTO messages (id, conversation_id, role, content, timestamp, rag_result_ids)
|
||||
// VALUES (:id, :conversation_id, :role, :content, now(), [])
|
||||
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Get message by ID
|
||||
pub async fn get_message(&self, _id: &str) -> Result<Option<Message>> {
|
||||
// TODO: SELECT * FROM messages WHERE id = :id
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// List messages with filters
///
/// Stub: returns an empty list; `_filter` (including its limit/offset) is
/// ignored until the query below is implemented.
pub async fn list_messages(&self, _filter: &MessageFilter) -> Result<Vec<Message>> {
    // TODO: Build query with filters, order by timestamp, apply limit/offset
    Ok(vec![])
}
|
||||
|
||||
/// Search messages with full-text search
///
/// Uses SurrealDB full-text search to find messages containing specific terms
///
/// Stub: logs the request and returns an empty list until the query below
/// is implemented.
pub async fn search_messages(
    &self,
    conversation_id: &str,
    query: &str,
) -> Result<Vec<Message>> {
    tracing::debug!(conversation_id = %conversation_id, query = %query, "Searching messages");

    // TODO: SELECT * FROM messages WHERE conversation_id = :conversation_id AND content CONTAINS :query
    Ok(vec![])
}
|
||||
|
||||
/// Link RAG results to message
///
/// Associates RAG retrieval results with a message for relationship queries
///
/// Stub: reports success without writing the association until the UPDATE
/// below is implemented.
pub async fn link_rag_results(
    &self,
    _message_id: &str,
    _rag_result_ids: Vec<String>,
) -> Result<()> {
    // TODO: UPDATE messages SET rag_result_ids = :rag_result_ids WHERE id = :message_id
    Ok(())
}
|
||||
|
||||
// ============================================================================
|
||||
// RAG RESULT OPERATIONS
|
||||
// ============================================================================
|
||||
|
||||
/// Create RAG retrieval result
|
||||
///
|
||||
/// Stores reusable RAG results that can be linked to multiple messages
|
||||
pub async fn create_rag_result(
|
||||
&self,
|
||||
_doc_id: &str,
|
||||
_content: &str,
|
||||
_combined_score: f32,
|
||||
_semantic_score: f32,
|
||||
_keyword_score: f32,
|
||||
) -> Result<String> {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
// TODO: INSERT INTO rag_results (...) VALUES (...)
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Get RAG result by ID
///
/// Stub: always returns `Ok(None)` until the query below is implemented.
pub async fn get_rag_result(&self, _id: &str) -> Result<Option<RagResult>> {
    // TODO: SELECT * FROM rag_results WHERE id = :id
    Ok(None)
}
|
||||
|
||||
// ============================================================================
|
||||
// SCHEMA OPERATIONS
|
||||
// ============================================================================
|
||||
|
||||
/// Create schema record
|
||||
///
|
||||
/// Stores Nickel schema for reuse across conversations
|
||||
pub async fn create_schema(&self, name: &str, _path: &str, _content: &str) -> Result<String> {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
tracing::debug!(schema_id = %id, schema_name = %name, "Creating schema");
|
||||
|
||||
// TODO: INSERT INTO schemas (id, name, path, content, created_at, updated_at) VALUES (...)
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Get schema by ID
///
/// Stub: always returns `Ok(None)` until the query below is implemented.
pub async fn get_schema(&self, _id: &str) -> Result<Option<Schema>> {
    // TODO: SELECT * FROM schemas WHERE id = :id
    Ok(None)
}
|
||||
|
||||
/// Get schema by name
///
/// Stub: always returns `Ok(None)` until the query below is implemented.
pub async fn get_schema_by_name(&self, _name: &str) -> Result<Option<Schema>> {
    // TODO: SELECT * FROM schemas WHERE name = :name LIMIT 1
    Ok(None)
}
|
||||
|
||||
/// Get all schemas
///
/// Stub: returns an empty list until the query below is implemented.
pub async fn list_schemas(&self) -> Result<Vec<Schema>> {
    // TODO: SELECT * FROM schemas ORDER BY created_at DESC
    Ok(vec![])
}
|
||||
|
||||
// ============================================================================
|
||||
// FIELD OPERATIONS
|
||||
// ============================================================================
|
||||
|
||||
/// Create field definition
|
||||
///
|
||||
/// Stores extracted field metadata from Nickel schema
|
||||
pub async fn create_field(
|
||||
&self,
|
||||
_schema_id: &str,
|
||||
_name: &str,
|
||||
_field_type: &str,
|
||||
_description: &str,
|
||||
_required: bool,
|
||||
) -> Result<String> {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
// TODO: INSERT INTO fields (id, schema_id, name, field_type, description, required, created_at)
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Get fields for schema
///
/// Stub: returns an empty list until the query below is implemented.
pub async fn get_schema_fields(&self, _schema_id: &str) -> Result<Vec<Field>> {
    // TODO: SELECT * FROM fields WHERE schema_id = :schema_id
    Ok(vec![])
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an in-memory client; panics (failing the test) on error.
    async fn mem_client() -> SurrealDbClient {
        SurrealDbClient::new("memory://", "default", "typedialog")
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_client_creation() {
        let db = mem_client().await;
        assert_eq!(db.endpoint(), "memory://");
        assert_eq!(db.namespace(), "default");
        assert_eq!(db.database(), "typedialog");
    }

    #[tokio::test]
    async fn test_invalid_endpoint() {
        // Unknown scheme must be rejected at construction time.
        assert!(
            SurrealDbClient::new("invalid://endpoint", "default", "typedialog")
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn test_file_endpoint() {
        let db = SurrealDbClient::new("file://./test.db", "default", "typedialog")
            .await
            .unwrap();
        assert_eq!(db.endpoint(), "file://./test.db");
    }

    #[tokio::test]
    async fn test_http_endpoint() {
        let db = SurrealDbClient::new("http://localhost:8000", "default", "typedialog")
            .await
            .unwrap();
        assert_eq!(db.endpoint(), "http://localhost:8000");
    }

    #[tokio::test]
    async fn test_initialize() {
        assert!(mem_client().await.initialize().await.is_ok());
    }

    #[tokio::test]
    async fn test_create_conversation() {
        let id = mem_client()
            .await
            .create_conversation("schema1")
            .await
            .unwrap();
        assert!(!id.is_empty());
        assert_eq!(id.len(), 36); // UUID format: 8-4-4-4-12
    }

    #[tokio::test]
    async fn test_create_message() {
        let db = mem_client().await;
        assert!(db
            .create_message("conv1", MessageRole::User, "Hello")
            .await
            .is_ok());
    }

    #[tokio::test]
    async fn test_create_schema() {
        let db = mem_client().await;
        assert!(db.create_schema("config", "schema.ncl", "{}").await.is_ok());
    }
}
|
||||
18
crates/typedialog-ai/src/storage/mod.rs
Normal file
18
crates/typedialog-ai/src/storage/mod.rs
Normal file
@ -0,0 +1,18 @@
|
||||
//! SurrealDB storage layer for TypeDialog AI service
|
||||
//!
|
||||
//! Provides persistent storage with:
|
||||
//! - Separate tables for conversations, messages, RAG results, schemas, fields
|
||||
//! - Graph relationships between entities
|
||||
//! - Full-text search on messages
|
||||
//! - Schema validation with SurrealQL contracts
|
||||
//! - Reusable messages and RAG results
|
||||
|
||||
pub mod client;
|
||||
pub mod models;
|
||||
|
||||
pub use client::SurrealDbClient;
|
||||
#[allow(unused_imports)]
|
||||
pub use models::{
|
||||
Conversation, ConversationFilter, ConversationStatus, Field, Message, MessageFilter,
|
||||
MessageRole, RagResult, Schema,
|
||||
};
|
||||
254
crates/typedialog-ai/src/storage/models.rs
Normal file
254
crates/typedialog-ai/src/storage/models.rs
Normal file
@ -0,0 +1,254 @@
|
||||
//! Data models for SurrealDB storage
|
||||
//!
|
||||
//! Defines the schema for all SurrealDB tables with support for:
|
||||
//! - Graph relationships between conversations, schemas, fields
|
||||
//! - Full-text search on messages
|
||||
//! - Reusable messages and RAG results
|
||||
//! - Schema validation with SurrealQL contracts
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Conversation record in SurrealDB
|
||||
///
|
||||
/// Represents a user conversation session with a particular schema
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Conversation {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
|
||||
/// Schema being configured
|
||||
pub schema_id: String,
|
||||
|
||||
/// Creation timestamp
|
||||
pub created_at: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
pub updated_at: DateTime<Utc>,
|
||||
|
||||
/// Conversation status
|
||||
pub status: ConversationStatus,
|
||||
|
||||
/// Collected field values as JSON
|
||||
#[serde(default)]
|
||||
pub collected_fields: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Conversation status
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum ConversationStatus {
|
||||
#[serde(rename = "active")]
|
||||
Active,
|
||||
#[serde(rename = "completed")]
|
||||
Completed,
|
||||
#[serde(rename = "abandoned")]
|
||||
Abandoned,
|
||||
}
|
||||
|
||||
/// Message record in SurrealDB
|
||||
///
|
||||
/// Individual message in a conversation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
|
||||
/// Conversation this message belongs to
|
||||
pub conversation_id: String,
|
||||
|
||||
/// Message role
|
||||
pub role: MessageRole,
|
||||
|
||||
/// Message content
|
||||
pub content: String,
|
||||
|
||||
/// Message timestamp
|
||||
pub timestamp: DateTime<Utc>,
|
||||
|
||||
/// IDs of RAG results used for this message
|
||||
#[serde(default)]
|
||||
pub rag_result_ids: Vec<String>,
|
||||
|
||||
/// Full-text search index
|
||||
#[serde(skip)]
|
||||
pub search_text: String,
|
||||
}
|
||||
|
||||
/// Message role in conversation
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum MessageRole {
|
||||
#[serde(rename = "user")]
|
||||
User,
|
||||
#[serde(rename = "assistant")]
|
||||
Assistant,
|
||||
}
|
||||
|
||||
/// RAG retrieval result
|
||||
///
|
||||
/// Reusable RAG results that can be linked to multiple messages
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RagResult {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
|
||||
/// Source document ID
|
||||
pub doc_id: String,
|
||||
|
||||
/// Retrieved content
|
||||
pub content: String,
|
||||
|
||||
/// Relevance score (0-1)
|
||||
pub combined_score: f32,
|
||||
|
||||
/// Semantic score component
|
||||
pub semantic_score: f32,
|
||||
|
||||
/// Keyword score component
|
||||
pub keyword_score: f32,
|
||||
|
||||
/// Creation timestamp
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Schema record
|
||||
///
|
||||
/// Stores Nickel schemas for reuse across conversations
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Schema {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
|
||||
/// Schema name
|
||||
pub name: String,
|
||||
|
||||
/// Schema file path
|
||||
pub path: String,
|
||||
|
||||
/// Full Nickel schema content
|
||||
pub content: String,
|
||||
|
||||
/// Creation timestamp
|
||||
pub created_at: DateTime<Utc>,
|
||||
|
||||
/// Last update timestamp
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Schema field definition
|
||||
///
|
||||
/// Individual field extracted from a Nickel schema
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Field {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
|
||||
/// Schema this field belongs to
|
||||
pub schema_id: String,
|
||||
|
||||
/// Field name
|
||||
pub name: String,
|
||||
|
||||
/// Field type (String, Number, Bool, etc.)
|
||||
pub field_type: String,
|
||||
|
||||
/// Human-readable description
|
||||
pub description: String,
|
||||
|
||||
/// Is field required?
|
||||
pub required: bool,
|
||||
|
||||
/// Default value if any
|
||||
pub default_value: Option<String>,
|
||||
|
||||
/// Creation timestamp
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Query filter for messages
|
||||
pub struct MessageFilter {
|
||||
pub conversation_id: Option<String>,
|
||||
pub role: Option<MessageRole>,
|
||||
pub limit: Option<u32>,
|
||||
pub offset: Option<u32>,
|
||||
}
|
||||
|
||||
impl Default for MessageFilter {
|
||||
fn default() -> Self {
|
||||
MessageFilter {
|
||||
conversation_id: None,
|
||||
role: None,
|
||||
limit: Some(50),
|
||||
offset: Some(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Query filter for conversations
|
||||
pub struct ConversationFilter {
|
||||
pub schema_id: Option<String>,
|
||||
pub status: Option<ConversationStatus>,
|
||||
pub limit: Option<u32>,
|
||||
pub offset: Option<u32>,
|
||||
}
|
||||
|
||||
impl Default for ConversationFilter {
|
||||
fn default() -> Self {
|
||||
ConversationFilter {
|
||||
schema_id: None,
|
||||
status: None,
|
||||
limit: Some(50),
|
||||
offset: Some(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    #[test]
    fn test_conversation_creation() {
        let conversation = Conversation {
            id: Some(Uuid::new_v4().to_string()),
            schema_id: "schema1".to_string(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            status: ConversationStatus::Active,
            collected_fields: serde_json::json!({}),
        };

        assert_eq!(conversation.status, ConversationStatus::Active);
    }

    #[test]
    fn test_message_serialization() {
        let message = Message {
            id: Some(Uuid::new_v4().to_string()),
            conversation_id: "conv1".to_string(),
            role: MessageRole::User,
            content: "Hello".to_string(),
            timestamp: Utc::now(),
            rag_result_ids: vec![],
            search_text: "hello".to_string(),
        };

        // Round-trip: both directions must succeed without panicking.
        let encoded = serde_json::to_string(&message).unwrap();
        let _ = serde_json::from_str::<Message>(&encoded).unwrap();
    }

    #[test]
    fn test_rag_result_scores() {
        let result = RagResult {
            id: Some(Uuid::new_v4().to_string()),
            doc_id: "doc1".to_string(),
            content: "content".to_string(),
            combined_score: 0.85,
            semantic_score: 0.90,
            keyword_score: 0.80,
            created_at: Utc::now(),
        };

        // Combined score is documented as a 0-1 relevance value.
        assert!((0.0..=1.0).contains(&result.combined_score));
    }
}
|
||||
24
crates/typedialog-ai/src/web_ui/mod.rs
Normal file
24
crates/typedialog-ai/src/web_ui/mod.rs
Normal file
@ -0,0 +1,24 @@
|
||||
//! Web UI module for serving the interactive web interface
|
||||
|
||||
use axum::response::{Html, IntoResponse};
|
||||
|
||||
/// Serve the main HTML page
///
/// The page is embedded at compile time via `include_str!`, so the binary
/// is self-contained and needs no static-file directory at runtime.
pub async fn index() -> Html<&'static str> {
    Html(include_str!("static/index.html"))
}
|
||||
|
||||
/// Serve the CSS styles
|
||||
pub async fn styles() -> impl IntoResponse {
|
||||
(
|
||||
[("Content-Type", "text/css")],
|
||||
include_str!("static/styles.css"),
|
||||
)
|
||||
}
|
||||
|
||||
/// Serve the JavaScript application
|
||||
pub async fn app() -> impl IntoResponse {
|
||||
(
|
||||
[("Content-Type", "application/javascript")],
|
||||
include_str!("static/app.js"),
|
||||
)
|
||||
}
|
||||
587
crates/typedialog-ai/src/web_ui/static/app.js
Normal file
587
crates/typedialog-ai/src/web_ui/static/app.js
Normal file
@ -0,0 +1,587 @@
|
||||
/**
|
||||
* TypeDialog AI Configuration Assistant - Web UI
|
||||
* Handles interactive configuration generation with real-time streaming
|
||||
*/
|
||||
|
||||
// ============================================================================
|
||||
// Global State
|
||||
// ============================================================================
|
||||
|
||||
// Mutable UI state shared by every handler below.
const state = {
    conversationId: null,  // current conversation ID; null until started / after clear
    selectedSchema: null,  // schema chosen in the sidebar select
    ws: null,              // active WebSocket, or null when not connected
    isConnected: false,    // WebSocket open?
    isStreaming: false,    // assistant reply currently streaming in?
    generatedConfig: null, // last generated config text, or null
    messageHistory: [],    // [{role, content, timestamp, isStreaming?}, ...]
    suggestions: [],       // [{field, value, reasoning, confidence}, ...]
};
|
||||
|
||||
// ============================================================================
|
||||
// DOM References
|
||||
// ============================================================================
|
||||
|
||||
// Cached DOM references, looked up once at script load.
// NOTE(review): assumes every ID below exists in index.html — a missing
// element yields null here and a TypeError on first use; verify markup.
const elements = {
    // Sidebar
    schemaSelect: document.getElementById('schemaSelect'),
    startBtn: document.getElementById('startBtn'),
    copyBtn: document.getElementById('copyBtn'),
    conversationId: document.getElementById('conversationId'),
    requiredFields: document.getElementById('requiredFields'),
    suggestBtn: document.getElementById('suggestBtn'),
    generateBtn: document.getElementById('generateBtn'),
    clearBtn: document.getElementById('clearBtn'),

    // Chat
    chatContainer: document.getElementById('chatContainer'),
    messageInput: document.getElementById('messageInput'),
    sendBtn: document.getElementById('sendBtn'),

    // Suggestions
    suggestionsSection: document.getElementById('suggestionsSection'),
    suggestionsContainer: document.getElementById('suggestionsContainer'),

    // Config
    configSection: document.getElementById('configSection'),
    formatSelect: document.getElementById('formatSelect'),
    configPreview: document.getElementById('configPreview'),
    downloadBtn: document.getElementById('downloadBtn'),
    copyConfigBtn: document.getElementById('copyConfigBtn'),

    // Status
    statusText: document.getElementById('statusText'),
    connectionStatus: document.getElementById('connectionStatus'),
    loadingOverlay: document.getElementById('loadingOverlay'),
};
|
||||
|
||||
// ============================================================================
|
||||
// Initialization
|
||||
// ============================================================================
|
||||
|
||||
// Wire up the UI once the DOM is ready.
document.addEventListener('DOMContentLoaded', () => {
    setupEventListeners();
    updateUI();
});

// Attach all static event handlers; called exactly once from the
// DOMContentLoaded hook above.
function setupEventListeners() {
    elements.schemaSelect.addEventListener('change', onSchemaChange);
    elements.startBtn.addEventListener('click', onStartConversation);
    elements.copyBtn.addEventListener('click', onCopyConversationId);
    elements.sendBtn.addEventListener('click', onSendMessage);
    // Enter sends the message; Shift+Enter keeps the default newline.
    elements.messageInput.addEventListener('keypress', (e) => {
        if (e.key === 'Enter' && !e.shiftKey) {
            e.preventDefault();
            onSendMessage();
        }
    });
    elements.suggestBtn.addEventListener('click', onGetSuggestions);
    elements.generateBtn.addEventListener('click', onGenerateConfig);
    elements.downloadBtn.addEventListener('click', onDownloadConfig);
    elements.copyConfigBtn.addEventListener('click', onCopyConfig);
    elements.formatSelect.addEventListener('change', onFormatChange);
    elements.clearBtn.addEventListener('click', onClearConversation);
}
|
||||
|
||||
// ============================================================================
|
||||
// Event Handlers
|
||||
// ============================================================================
|
||||
|
||||
// Sidebar schema <select> changed: record the choice and refresh the UI.
function onSchemaChange() {
    // Normalize an empty placeholder value ("") to null: updateUI() gates
    // the Start button on `state.selectedSchema !== null`, so storing ""
    // would wrongly enable it. Assumes the select may have a placeholder
    // option with an empty value — TODO confirm against index.html.
    state.selectedSchema = elements.schemaSelect.value || null;
    updateUI();
    if (state.selectedSchema) {
        updateRequiredFields(state.selectedSchema);
    }
}
|
||||
|
||||
// Start a new conversation for the selected schema: allocate an ID,
// reset the chat, and open a WebSocket to the backend.
function onStartConversation() {
    if (!state.selectedSchema) {
        showStatus('Please select a schema', 'error');
        return;
    }

    // Close any socket left over from a previous conversation so we don't
    // leak the connection or keep streaming under a stale conversation ID.
    if (state.ws) {
        state.ws.close();
        state.ws = null;
        state.isConnected = false;
    }

    // Generate a conversation ID
    state.conversationId = generateConversationId();
    elements.conversationId.textContent = state.conversationId;

    // Clear chat
    state.messageHistory = [];
    renderChat();

    // Connect WebSocket
    connectWebSocket();

    updateUI();
    showStatus(`Conversation started: ${state.conversationId}`);
}
|
||||
|
||||
// Send the typed message over the WebSocket and echo it into the chat.
function onSendMessage() {
    const message = elements.messageInput.value.trim();
    if (!message) return;
    if (!state.ws || !state.isConnected) {
        showStatus('Not connected', 'error');
        return;
    }

    // Add user message to history
    state.messageHistory.push({
        role: 'user',
        content: message,
        timestamp: new Date(),
    });

    // Clear input and render
    elements.messageInput.value = '';
    renderChat();

    // Send via WebSocket; server replies arrive in handleWebSocketMessage.
    const wsMessage = {
        type: 'message',
        content: message,
    };
    state.ws.send(JSON.stringify(wsMessage));
    showStatus('Sending message...');
}
|
||||
|
||||
// Populate the suggestions panel.
// NOTE(review): currently returns hard-coded demo data after a 500 ms
// delay — in production this should fetch from the API, per the comment.
function onGetSuggestions() {
    if (!state.conversationId) {
        showStatus('Start a conversation first', 'error');
        return;
    }

    showLoading(true);
    showStatus('Fetching suggestions...');

    // Simulate suggestions (in production, would fetch from API)
    setTimeout(() => {
        state.suggestions = [
            {
                field: 'port',
                value: '8080',
                reasoning: 'Standard web server port',
                confidence: 0.92,
            },
            {
                field: 'host',
                value: 'localhost',
                reasoning: 'Development default',
                confidence: 0.85,
            },
            {
                field: 'timeout',
                value: '30',
                reasoning: 'Balanced responsiveness',
                confidence: 0.78,
            },
        ];
        renderSuggestions();
        showLoading(false);
        showStatus('Suggestions loaded');
    }, 500);
}
|
||||
|
||||
// Generate a configuration preview in the selected format.
// NOTE(review): currently a 1 s simulated delay around a local template —
// in production this should call the generation API, per the comment.
function onGenerateConfig() {
    if (!state.conversationId) {
        showStatus('Start a conversation first', 'error');
        return;
    }

    showLoading(true);
    showStatus('Generating configuration...');

    // Simulate config generation (in production, would call API)
    setTimeout(() => {
        const format = elements.formatSelect.value;
        state.generatedConfig = generateConfigPreview(format);
        renderConfigPreview();
        showLoading(false);
        showStatus('Configuration generated');
    }, 1000);
}
|
||||
|
||||
// Download the generated config as a file named after the chosen format.
function onDownloadConfig() {
    if (!state.generatedConfig) {
        showStatus('Generate a configuration first', 'error');
        return;
    }

    const format = elements.formatSelect.value;
    const filename = `config.${format}`;
    // Build a blob URL, click a synthetic anchor, then free the URL.
    const blob = new Blob([state.generatedConfig], { type: 'text/plain' });
    const url = URL.createObjectURL(blob);
    const a = document.createElement('a');
    a.href = url;
    a.download = filename;
    a.click();
    URL.revokeObjectURL(url);

    showStatus('Configuration downloaded');
}
|
||||
|
||||
// Copy the generated config to the clipboard.
function onCopyConfig() {
    if (!state.generatedConfig) {
        showStatus('Generate a configuration first', 'error');
        return;
    }

    navigator.clipboard
        .writeText(state.generatedConfig)
        .then(() => {
            showStatus('Configuration copied to clipboard');
        })
        .catch(() => {
            // Clipboard writes can fail (insecure context, denied
            // permission); surface it instead of an unhandled rejection.
            showStatus('Failed to copy configuration', 'error');
        });
}
|
||||
|
||||
// Copy the current conversation ID to the clipboard.
function onCopyConversationId() {
    if (!state.conversationId) return;
    navigator.clipboard
        .writeText(state.conversationId)
        .then(() => {
            showStatus('Conversation ID copied');
        })
        .catch(() => {
            // Clipboard writes can fail (insecure context, denied
            // permission); surface it instead of an unhandled rejection.
            showStatus('Failed to copy conversation ID', 'error');
        });
}
|
||||
|
||||
// Format <select> changed: regenerate the preview in the new format, but
// only if a config has already been generated.
function onFormatChange() {
    if (!state.generatedConfig) return;
    state.generatedConfig = generateConfigPreview(elements.formatSelect.value);
    renderConfigPreview();
}
|
||||
|
||||
// Clear the conversation after user confirmation: reset all state,
// tear down the socket, and refresh the panels.
function onClearConversation() {
    if (!confirm('Clear conversation history?')) return;

    // Close the live socket too — the old conversation ID is now stale;
    // leaving it open would leak the connection and stream into a cleared UI.
    if (state.ws) {
        state.ws.close();
        state.ws = null;
    }
    state.isConnected = false;
    state.isStreaming = false;

    state.conversationId = null;
    state.messageHistory = [];
    state.generatedConfig = null;
    state.suggestions = [];
    renderChat();
    // Also hide any suggestion cards left from the cleared conversation.
    renderSuggestions();
    updateUI();
    showStatus('Conversation cleared');
}
|
||||
|
||||
// ============================================================================
|
||||
// WebSocket Management
|
||||
// ============================================================================
|
||||
|
||||
// Open a WebSocket for the current conversation and wire its lifecycle
// handlers into `state` / the status bar.
function connectWebSocket() {
    // Match ws/wss to the page's own scheme.
    const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
    const wsUrl = `${protocol}//${window.location.host}/ws/${state.conversationId}`;

    try {
        state.ws = new WebSocket(wsUrl);

        state.ws.onopen = () => {
            state.isConnected = true;
            updateUI();
            showStatus('Connected');
        };

        state.ws.onmessage = (event) => {
            try {
                const response = JSON.parse(event.data);
                handleWebSocketMessage(response);
            } catch (e) {
                // Non-JSON frames are logged and dropped, not fatal.
                console.error('Failed to parse WebSocket message:', e);
            }
        };

        state.ws.onerror = (error) => {
            console.error('WebSocket error:', error);
            showStatus('Connection error', 'error');
        };

        state.ws.onclose = () => {
            state.isConnected = false;
            updateUI();
            showStatus('Disconnected', 'error');
        };
    } catch (e) {
        // WebSocket constructor can throw synchronously (e.g. bad URL).
        showStatus('Failed to connect', 'error');
        console.error(e);
    }
}
|
||||
|
||||
// Dispatch one parsed server frame. Protocol (as used here):
//   start  — begin a streamed assistant reply (placeholder bubble)
//   chunk  — append `content` to the in-progress assistant reply
//   end    — finish the streamed reply
//   error  — show `content` in the status bar
//   suggestions — `content` is a JSON array for the suggestions panel
function handleWebSocketMessage(response) {
    switch (response.type) {
        case 'start':
            state.isStreaming = true;
            // Push an empty assistant bubble that chunks will fill in.
            state.messageHistory.push({
                role: 'assistant',
                content: '',
                timestamp: new Date(),
                isStreaming: true,
            });
            renderChat();
            break;

        case 'chunk':
            // Append to the newest message if it is the assistant's.
            if (state.messageHistory.length > 0) {
                const lastMsg = state.messageHistory[state.messageHistory.length - 1];
                if (lastMsg.role === 'assistant') {
                    lastMsg.content += response.content;
                }
            }
            renderChat();
            break;

        case 'end':
            state.isStreaming = false;
            if (state.messageHistory.length > 0) {
                const lastMsg = state.messageHistory[state.messageHistory.length - 1];
                lastMsg.isStreaming = false;
            }
            renderChat();
            showStatus('Message received');
            break;

        case 'error':
            showStatus(`Error: ${response.content}`, 'error');
            break;

        case 'suggestions':
            try {
                state.suggestions = JSON.parse(response.content);
                renderSuggestions();
            } catch (e) {
                console.error('Failed to parse suggestions:', e);
            }
            break;

        default:
            console.warn('Unknown message type:', response.type);
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// Rendering Functions
|
||||
// ============================================================================
|
||||
|
||||
// Re-render the whole chat transcript from state.messageHistory and
// scroll to the bottom. Streaming messages get an animated placeholder.
function renderChat() {
    if (state.messageHistory.length === 0) {
        elements.chatContainer.innerHTML = `
            <div class="welcome-message">
                <p>👋 Welcome to TypeDialog AI Configuration Assistant</p>
                <p>Type a configuration question to begin</p>
            </div>
        `;
        return;
    }

    const messageHtml = state.messageHistory
        .map((msg) => {
            const timeStr = msg.timestamp.toLocaleTimeString();
            if (msg.isStreaming) {
                // Three-dot typing indicator while chunks stream in.
                return `
                    <div class="message ${msg.role}">
                        <div class="message-bubble">
                            <div class="streaming">
                                <span></span>
                                <span></span>
                                <span></span>
                            </div>
                        </div>
                    </div>
                `;
            }
            // escapeHtml keeps user/assistant text from injecting markup.
            return `
                <div class="message ${msg.role}">
                    <div class="message-bubble">${escapeHtml(msg.content)}</div>
                    <div class="message-time">${timeStr}</div>
                </div>
            `;
        })
        .join('');

    elements.chatContainer.innerHTML = messageHtml;
    elements.chatContainer.scrollTop = elements.chatContainer.scrollHeight;
}
|
||||
|
||||
// Render suggestion cards; hides the whole section when there are none.
// NOTE(review): relies on `escapeAttr` (defined elsewhere in this file)
// to keep values safe inside the inline onclick attribute — confirm it
// escapes quotes as well as HTML.
function renderSuggestions() {
    if (state.suggestions.length === 0) {
        elements.suggestionsSection.style.display = 'none';
        return;
    }

    elements.suggestionsSection.style.display = 'block';
    const suggestionsHtml = state.suggestions
        .map((suggestion) => {
            const confidencePercent = (suggestion.confidence * 100).toFixed(0);
            return `
                <div class="suggestion-card" onclick="insertSuggestion('${escapeAttr(suggestion.field)}', '${escapeAttr(suggestion.value)}')">
                    <div class="suggestion-field">${escapeHtml(suggestion.field)}</div>
                    <div class="suggestion-value">${escapeHtml(suggestion.value)}</div>
                    <div class="suggestion-reasoning">${escapeHtml(suggestion.reasoning)}</div>
                    <div class="suggestion-confidence">
                        <span>${confidencePercent}%</span>
                        <div class="confidence-bar">
                            <div class="confidence-fill" style="width: ${confidencePercent}%"></div>
                        </div>
                    </div>
                </div>
            `;
        })
        .join('');

    elements.suggestionsContainer.innerHTML = suggestionsHtml;
}
|
||||
|
||||
// Reveal the config section and put the generated text into its <code>
// element (textContent, so the config is never parsed as HTML).
function renderConfigPreview() {
    elements.configSection.style.display = 'block';
    const code = elements.configPreview.querySelector('code');
    if (code) {
        code.textContent = state.generatedConfig;
    }
}
|
||||
|
||||
// Recompute every button's enabled state from `state`. Called after any
// state transition (schema change, connect/disconnect, generate, clear).
function updateUI() {
    const hasConversation = state.conversationId !== null;
    // NOTE(review): an empty-string schema value would pass this check —
    // see onSchemaChange for where the value is stored.
    const hasSchema = state.selectedSchema !== null;

    // Sidebar
    elements.startBtn.disabled = !hasSchema;
    elements.copyBtn.disabled = !hasConversation;
    elements.conversationId.textContent = state.conversationId || '-';

    // Chat: input only while connected; send also blocked mid-stream.
    elements.messageInput.disabled = !state.isConnected;
    elements.sendBtn.disabled = !state.isConnected || state.isStreaming;

    // Actions
    elements.suggestBtn.disabled = !hasConversation;
    elements.generateBtn.disabled = !hasConversation;
    elements.clearBtn.disabled = !hasConversation;
    elements.downloadBtn.disabled = !state.generatedConfig;
    elements.copyConfigBtn.disabled = !state.generatedConfig;

    // Status
    updateConnectionStatus();
}
|
||||
|
||||
// Reflect the WebSocket state in the status indicator: text plus exactly
// one of the `connected` / `disconnected` classes.
function updateConnectionStatus() {
    const indicator = elements.connectionStatus;
    const connected = state.isConnected;
    indicator.textContent = connected ? '● Connected' : '● Disconnected';
    indicator.classList.toggle('connected', connected);
    indicator.classList.toggle('disconnected', !connected);
}
|
||||
|
||||
// Show the required-field list for the chosen schema, or a placeholder
// when the schema is unknown / has no fields.
function updateRequiredFields(schema) {
    const fieldsMap = {
        ServerConfig: ['port', 'host', 'timeout'],
        DatabaseConfig: ['host', 'port', 'database', 'username'],
        ApplicationConfig: ['name', 'version', 'debug_mode', 'log_level'],
    };

    const fields = fieldsMap[schema] || [];
    elements.requiredFields.innerHTML =
        fields.length === 0
            ? '<li class="placeholder">No fields for this schema</li>'
            : fields
                  .map((field) => `<li class="required">${escapeHtml(field)}</li>`)
                  .join('');
}
|
||||
|
||||
// ============================================================================
|
||||
// Utility Functions
|
||||
// ============================================================================
|
||||
|
||||
// Pre-fill the chat input with a "Set <field> to <value>" command for a
// clicked suggestion card, then focus the input so the user can send it.
function insertSuggestion(field, value) {
    const input = elements.messageInput;
    input.value = `Set ${field} to ${value}`;
    input.focus();
}
|
||||
|
||||
// Generate a lightweight, client-side conversation identifier of the form
// `conv-<ms-timestamp>-<random base36 suffix>`. Not cryptographically
// unique — collisions are merely unlikely, which suffices for a UI session.
//
// @returns {string} e.g. "conv-1712345678901-k3j9x2ab1"
function generateConversationId() {
    // slice(2, 11) replaces the deprecated substr(2, 9): skip the "0."
    // prefix of the base-36 fraction and keep up to 9 characters.
    return `conv-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`;
}
|
||||
|
||||
// Show a transient message in the status bar. Errors are tinted with the
// danger color; everything else uses the muted text color. The bar resets
// to "Ready" after 3 seconds.
//
// @param {string} message - Text to display.
// @param {string} [type='info'] - 'error' selects the danger color.
function showStatus(message, type = 'info') {
    const statusEl = elements.statusText;
    const color =
        type === 'error' ? 'var(--color-danger)' : 'var(--color-text-light)';

    statusEl.textContent = message;
    statusEl.style.color = color;

    // NOTE(review): successive calls each arm their own timer — a later
    // message can be wiped early by an earlier timeout (same as original).
    setTimeout(() => {
        statusEl.textContent = 'Ready';
        statusEl.style.color = 'var(--color-text-light)';
    }, 3000);
}
|
||||
|
||||
// Toggle the full-screen loading overlay. The overlay uses display:flex
// to center its spinner when visible.
//
// @param {boolean} show - true to display the overlay, false to hide it.
function showLoading(show) {
    elements.loadingOverlay.style.display = show ? 'flex' : 'none';
}
|
||||
|
||||
// Build a mock configuration preview string in the requested format.
// Returns a hard-coded sample (server/database/logging sections) rather
// than real generated output — presumably a placeholder until the backend
// supplies the actual config; TODO confirm against the server handlers.
//
// @param {string} format - 'json', 'yaml', or 'toml'; any other value
//   falls back to the JSON sample.
// @returns {string} Pretty-printed configuration text for the preview pane.
function generateConfigPreview(format) {
    const configs = {
        // JSON sample, pretty-printed with 2-space indentation.
        json: JSON.stringify(
            {
                server: {
                    port: 8080,
                    host: 'localhost',
                    timeout: 30,
                },
                database: {
                    host: 'localhost',
                    port: 5432,
                    name: 'myapp',
                },
                logging: {
                    level: 'info',
                    format: 'json',
                },
            },
            null,
            2
        ),
        // YAML sample. The template literal's own layout IS the output;
        // nested keys carry 2-space indentation.
        yaml: `server:
  port: 8080
  host: localhost
  timeout: 30
database:
  host: localhost
  port: 5432
  name: myapp
logging:
  level: info
  format: json`,
        // TOML sample; tables are flush-left per TOML convention.
        toml: `[server]
port = 8080
host = "localhost"
timeout = 30

[database]
host = "localhost"
port = 5432
name = "myapp"

[logging]
level = "info"
format = "json"`,
    };

    // Unknown format keys silently degrade to the JSON preview.
    return configs[format] || configs.json;
}
|
||||
|
||||
// Escape text for safe insertion into HTML element content by round-tripping
// it through a detached element: assigning textContent and reading back
// innerHTML lets the browser perform the entity encoding.
//
// @param {string} text - Untrusted text.
// @returns {string} HTML-safe string ('&', '<', '>' entity-encoded).
function escapeHtml(text) {
    const sandbox = document.createElement('div');
    sandbox.textContent = text;
    return sandbox.innerHTML;
}
|
||||
|
||||
// Escape text for safe embedding inside an HTML attribute value by
// entity-encoding both quote characters.
//
// BUG FIX: the previous body's replacement strings had been entity-decoded
// back into literal quotes ('"' — a no-op — and ''', which is a syntax
// error), so the function escaped nothing. Restored the intended entities.
//
// @param {string} text - Untrusted text destined for an attribute value.
// @returns {string} String with '"' -> &quot; and "'" -> &#39;.
function escapeAttr(text) {
    return text.replace(/"/g, '&quot;').replace(/'/g, '&#39;');
}
|
||||
|
||||
// ============================================================================
|
||||
// Export for testing
|
||||
// ============================================================================
|
||||
|
||||
// CommonJS export shim: in a Node test environment (where `module` exists)
// expose selected internals for unit testing; in the browser this guard is
// false and the block is a no-op.
if (typeof module !== 'undefined' && module.exports) {
    module.exports = { state, generateConversationId, escapeHtml };
}
|
||||
121
crates/typedialog-ai/src/web_ui/static/index.html
Normal file
121
crates/typedialog-ai/src/web_ui/static/index.html
Normal file
@ -0,0 +1,121 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>TypeDialog AI Configuration Assistant</title>
|
||||
<link rel="stylesheet" href="styles.css">
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<!-- Header -->
|
||||
<header class="header">
|
||||
<div class="header-content">
|
||||
<h1>🤖 TypeDialog AI Configuration Assistant</h1>
|
||||
<p class="subtitle">Intelligent configuration generation with LLM assistance</p>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<!-- Main Layout -->
|
||||
<div class="main-layout">
|
||||
<!-- Sidebar -->
|
||||
<aside class="sidebar">
|
||||
<div class="sidebar-section">
|
||||
<h3>Schema Selection</h3>
|
||||
<select id="schemaSelect" class="schema-select">
|
||||
<option value="">Select a schema...</option>
|
||||
<option value="ServerConfig">Server Configuration</option>
|
||||
<option value="DatabaseConfig">Database Configuration</option>
|
||||
<option value="ApplicationConfig">Application Configuration</option>
|
||||
</select>
|
||||
<button id="startBtn" class="btn btn-primary" disabled>Start Conversation</button>
|
||||
</div>
|
||||
|
||||
<div class="sidebar-section">
|
||||
<h3>Conversation ID</h3>
|
||||
<div id="conversationId" class="conversation-id">-</div>
|
||||
<button id="copyBtn" class="btn btn-secondary" disabled>Copy ID</button>
|
||||
</div>
|
||||
|
||||
<div class="sidebar-section">
|
||||
<h3>Required Fields</h3>
|
||||
<ul id="requiredFields" class="fields-list">
|
||||
<li class="placeholder">Select a schema to see required fields</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="sidebar-section">
|
||||
<h3>Actions</h3>
|
||||
<button id="suggestBtn" class="btn btn-secondary" disabled>Get Suggestions</button>
|
||||
<button id="generateBtn" class="btn btn-secondary" disabled>Generate Config</button>
|
||||
<button id="clearBtn" class="btn btn-danger" disabled>Clear Conversation</button>
|
||||
</div>
|
||||
</aside>
|
||||
|
||||
<!-- Main Content -->
|
||||
<main class="main-content">
|
||||
<!-- Chat Section -->
|
||||
<section class="chat-section">
|
||||
<h2>Conversation</h2>
|
||||
<div id="chatContainer" class="chat-container">
|
||||
<div class="welcome-message">
|
||||
<p>👋 Welcome to TypeDialog AI Configuration Assistant</p>
|
||||
<p>Select a schema on the left to begin</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Input Area -->
|
||||
<div class="input-area">
|
||||
<textarea
|
||||
id="messageInput"
|
||||
class="message-input"
|
||||
placeholder="Ask a configuration question..."
|
||||
disabled
|
||||
></textarea>
|
||||
<button id="sendBtn" class="btn btn-primary" disabled>Send Message</button>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Suggestions Section -->
|
||||
<section class="suggestions-section" id="suggestionsSection" style="display: none;">
|
||||
<h3>Field Suggestions</h3>
|
||||
<div id="suggestionsContainer" class="suggestions-container">
|
||||
<!-- Suggestions will be populated here -->
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Configuration Preview Section -->
|
||||
<section class="config-section" id="configSection" style="display: none;">
|
||||
<div class="config-header">
|
||||
<h3>Configuration Preview</h3>
|
||||
<select id="formatSelect" class="format-select">
|
||||
<option value="json">JSON</option>
|
||||
<option value="yaml">YAML</option>
|
||||
<option value="toml">TOML</option>
|
||||
</select>
|
||||
</div>
|
||||
<pre id="configPreview" class="config-preview"><code>// No configuration generated yet</code></pre>
|
||||
<div class="config-actions">
|
||||
<button id="downloadBtn" class="btn btn-secondary" disabled>Download Config</button>
|
||||
<button id="copyConfigBtn" class="btn btn-secondary" disabled>Copy to Clipboard</button>
|
||||
</div>
|
||||
</section>
|
||||
</main>
|
||||
</div>
|
||||
|
||||
<!-- Status Bar -->
|
||||
<footer class="status-bar">
|
||||
<span id="statusText" class="status-text">Ready</span>
|
||||
<span id="connectionStatus" class="connection-status connected">● Connected</span>
|
||||
</footer>
|
||||
</div>
|
||||
|
||||
<!-- Loading Indicator -->
|
||||
<div id="loadingOverlay" class="loading-overlay" style="display: none;">
|
||||
<div class="spinner"></div>
|
||||
<p>Processing...</p>
|
||||
</div>
|
||||
|
||||
<script src="app.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
672
crates/typedialog-ai/src/web_ui/static/styles.css
Normal file
672
crates/typedialog-ai/src/web_ui/static/styles.css
Normal file
@ -0,0 +1,672 @@
|
||||
/* ============================================================================
|
||||
Global Styles
|
||||
============================================================================ */
|
||||
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
:root {
|
||||
--color-primary: #4f46e5;
|
||||
--color-primary-dark: #4338ca;
|
||||
--color-secondary: #8b5cf6;
|
||||
--color-success: #10b981;
|
||||
--color-danger: #ef4444;
|
||||
--color-warning: #f59e0b;
|
||||
--color-bg: #ffffff;
|
||||
--color-bg-alt: #f9fafb;
|
||||
--color-bg-hover: #f3f4f6;
|
||||
--color-border: #e5e7eb;
|
||||
--color-text: #1f2937;
|
||||
--color-text-light: #6b7280;
|
||||
--color-text-lighter: #9ca3af;
|
||||
|
||||
--font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
|
||||
--font-mono: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Courier New', monospace;
|
||||
|
||||
--shadow-sm: 0 1px 2px 0 rgba(0, 0, 0, 0.05);
|
||||
--shadow-md: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
|
||||
--shadow-lg: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
|
||||
|
||||
--border-radius: 8px;
|
||||
--transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: var(--font-family);
|
||||
background: var(--color-bg-alt);
|
||||
color: var(--color-text);
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
/* ============================================================================
|
||||
Layout
|
||||
============================================================================ */
|
||||
|
||||
.container {
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.header {
|
||||
background: linear-gradient(135deg, var(--color-primary) 0%, var(--color-secondary) 100%);
|
||||
color: white;
|
||||
padding: 1.5rem;
|
||||
box-shadow: var(--shadow-md);
|
||||
z-index: 100;
|
||||
}
|
||||
|
||||
.header-content h1 {
|
||||
font-size: 1.875rem;
|
||||
font-weight: 700;
|
||||
margin-bottom: 0.25rem;
|
||||
}
|
||||
|
||||
.subtitle {
|
||||
font-size: 0.875rem;
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
.main-layout {
|
||||
display: flex;
|
||||
flex: 1;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.sidebar {
|
||||
width: 300px;
|
||||
background: var(--color-bg);
|
||||
border-right: 1px solid var(--color-border);
|
||||
overflow-y: auto;
|
||||
padding: 1.5rem;
|
||||
gap: 1.5rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.main-content {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow: hidden;
|
||||
padding: 1.5rem;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.status-bar {
|
||||
background: var(--color-bg);
|
||||
border-top: 1px solid var(--color-border);
|
||||
padding: 0.75rem 1.5rem;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
font-size: 0.875rem;
|
||||
color: var(--color-text-light);
|
||||
}
|
||||
|
||||
/* ============================================================================
|
||||
Sidebar Components
|
||||
============================================================================ */
|
||||
|
||||
.sidebar-section {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.sidebar-section h3 {
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-light);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
}
|
||||
|
||||
.schema-select,
|
||||
.format-select {
|
||||
padding: 0.625rem;
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: var(--border-radius);
|
||||
font-family: var(--font-family);
|
||||
font-size: 0.875rem;
|
||||
background: var(--color-bg);
|
||||
color: var(--color-text);
|
||||
cursor: pointer;
|
||||
transition: var(--transition);
|
||||
}
|
||||
|
||||
.schema-select:hover,
|
||||
.format-select:hover {
|
||||
border-color: var(--color-primary);
|
||||
}
|
||||
|
||||
.schema-select:focus,
|
||||
.format-select:focus {
|
||||
outline: none;
|
||||
border-color: var(--color-primary);
|
||||
box-shadow: 0 0 0 3px rgba(79, 70, 229, 0.1);
|
||||
}
|
||||
|
||||
.conversation-id {
|
||||
padding: 0.75rem;
|
||||
background: var(--color-bg-alt);
|
||||
border-radius: var(--border-radius);
|
||||
font-family: var(--font-mono);
|
||||
font-size: 0.75rem;
|
||||
word-break: break-all;
|
||||
color: var(--color-text-light);
|
||||
}
|
||||
|
||||
.fields-list {
|
||||
list-style: none;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.fields-list li {
|
||||
padding: 0.5rem;
|
||||
background: var(--color-bg-alt);
|
||||
border-radius: var(--border-radius);
|
||||
font-size: 0.875rem;
|
||||
color: var(--color-text);
|
||||
}
|
||||
|
||||
.fields-list li.placeholder {
|
||||
color: var(--color-text-light);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.fields-list li.required::before {
|
||||
content: '● ';
|
||||
color: var(--color-danger);
|
||||
}
|
||||
|
||||
/* ============================================================================
|
||||
Buttons
|
||||
============================================================================ */
|
||||
|
||||
.btn {
|
||||
padding: 0.625rem 1rem;
|
||||
border: none;
|
||||
border-radius: var(--border-radius);
|
||||
font-family: var(--font-family);
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: var(--transition);
|
||||
width: 100%;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.btn:disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background: var(--color-primary);
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-primary:hover:not(:disabled) {
|
||||
background: var(--color-primary-dark);
|
||||
box-shadow: var(--shadow-md);
|
||||
}
|
||||
|
||||
.btn-secondary {
|
||||
background: var(--color-bg-alt);
|
||||
color: var(--color-text);
|
||||
border: 1px solid var(--color-border);
|
||||
}
|
||||
|
||||
.btn-secondary:hover:not(:disabled) {
|
||||
background: var(--color-bg-hover);
|
||||
border-color: var(--color-primary);
|
||||
}
|
||||
|
||||
.btn-danger {
|
||||
background: var(--color-danger);
|
||||
color: white;
|
||||
}
|
||||
|
||||
.btn-danger:hover:not(:disabled) {
|
||||
background: #dc2626;
|
||||
box-shadow: var(--shadow-md);
|
||||
}
|
||||
|
||||
/* ============================================================================
|
||||
Chat Section
|
||||
============================================================================ */
|
||||
|
||||
.chat-section {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.chat-section h2 {
|
||||
font-size: 1.125rem;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.chat-container {
|
||||
flex: 1;
|
||||
background: var(--color-bg);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: var(--border-radius);
|
||||
padding: 1rem;
|
||||
overflow-y: auto;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.welcome-message {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.5rem;
|
||||
text-align: center;
|
||||
color: var(--color-text-light);
|
||||
padding: 2rem;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.welcome-message p:first-child {
|
||||
font-size: 1.5rem;
|
||||
}
|
||||
|
||||
.message {
|
||||
display: flex;
|
||||
gap: 0.75rem;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.message.user {
|
||||
justify-content: flex-end;
|
||||
}
|
||||
|
||||
.message.assistant {
|
||||
justify-content: flex-start;
|
||||
}
|
||||
|
||||
.message-bubble {
|
||||
max-width: 70%;
|
||||
padding: 0.75rem 1rem;
|
||||
border-radius: var(--border-radius);
|
||||
word-wrap: break-word;
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
.message.user .message-bubble {
|
||||
background: var(--color-primary);
|
||||
color: white;
|
||||
}
|
||||
|
||||
.message.assistant .message-bubble {
|
||||
background: var(--color-bg-alt);
|
||||
color: var(--color-text);
|
||||
border: 1px solid var(--color-border);
|
||||
}
|
||||
|
||||
.message-time {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-lighter);
|
||||
padding: 0 0.5rem;
|
||||
align-self: flex-end;
|
||||
}
|
||||
|
||||
.streaming {
|
||||
display: flex;
|
||||
gap: 0.25rem;
|
||||
}
|
||||
|
||||
.streaming span {
|
||||
width: 0.5rem;
|
||||
height: 0.5rem;
|
||||
background: var(--color-text-light);
|
||||
border-radius: 50%;
|
||||
animation: bounce 1.4s infinite;
|
||||
}
|
||||
|
||||
.streaming span:nth-child(2) {
|
||||
animation-delay: 0.2s;
|
||||
}
|
||||
|
||||
.streaming span:nth-child(3) {
|
||||
animation-delay: 0.4s;
|
||||
}
|
||||
|
||||
@keyframes bounce {
|
||||
0%, 80%, 100% {
|
||||
opacity: 0.3;
|
||||
transform: translateY(0);
|
||||
}
|
||||
40% {
|
||||
opacity: 1;
|
||||
transform: translateY(-0.5rem);
|
||||
}
|
||||
}
|
||||
|
||||
.input-area {
|
||||
display: flex;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.message-input {
|
||||
flex: 1;
|
||||
padding: 0.75rem;
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: var(--border-radius);
|
||||
font-family: var(--font-family);
|
||||
font-size: 0.875rem;
|
||||
resize: none;
|
||||
max-height: 120px;
|
||||
color: var(--color-text);
|
||||
}
|
||||
|
||||
.message-input:focus {
|
||||
outline: none;
|
||||
border-color: var(--color-primary);
|
||||
box-shadow: 0 0 0 3px rgba(79, 70, 229, 0.1);
|
||||
}
|
||||
|
||||
.message-input:disabled {
|
||||
background: var(--color-bg-alt);
|
||||
color: var(--color-text-light);
|
||||
}
|
||||
|
||||
/* ============================================================================
|
||||
Suggestions Section
|
||||
============================================================================ */
|
||||
|
||||
.suggestions-section {
|
||||
background: var(--color-bg);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: var(--border-radius);
|
||||
padding: 1rem;
|
||||
margin-top: 1rem;
|
||||
}
|
||||
|
||||
.suggestions-section h3 {
|
||||
margin-bottom: 0.75rem;
|
||||
}
|
||||
|
||||
.suggestions-container {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.suggestion-card {
|
||||
background: var(--color-bg-alt);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: var(--border-radius);
|
||||
padding: 1rem;
|
||||
cursor: pointer;
|
||||
transition: var(--transition);
|
||||
}
|
||||
|
||||
.suggestion-card:hover {
|
||||
border-color: var(--color-primary);
|
||||
box-shadow: var(--shadow-md);
|
||||
transform: translateY(-2px);
|
||||
}
|
||||
|
||||
.suggestion-field {
|
||||
font-weight: 600;
|
||||
color: var(--color-primary);
|
||||
font-size: 0.875rem;
|
||||
text-transform: uppercase;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.suggestion-value {
|
||||
font-family: var(--font-mono);
|
||||
font-size: 0.875rem;
|
||||
padding: 0.5rem;
|
||||
background: var(--color-bg);
|
||||
border-radius: 4px;
|
||||
margin-bottom: 0.5rem;
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.suggestion-reasoning {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-light);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.suggestion-confidence {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-light);
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.confidence-bar {
|
||||
flex: 1;
|
||||
height: 4px;
|
||||
background: var(--color-border);
|
||||
border-radius: 2px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.confidence-fill {
|
||||
height: 100%;
|
||||
background: var(--color-success);
|
||||
border-radius: 2px;
|
||||
}
|
||||
|
||||
/* ============================================================================
|
||||
Configuration Section
|
||||
============================================================================ */
|
||||
|
||||
.config-section {
|
||||
background: var(--color-bg);
|
||||
border: 1px solid var(--color-border);
|
||||
border-radius: var(--border-radius);
|
||||
padding: 1rem;
|
||||
margin-top: 1rem;
|
||||
}
|
||||
|
||||
.config-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.config-header h3 {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.format-select {
|
||||
padding: 0.5rem;
|
||||
width: 150px;
|
||||
}
|
||||
|
||||
.config-preview {
|
||||
background: var(--color-text);
|
||||
color: #e0e0e0;
|
||||
padding: 1rem;
|
||||
border-radius: var(--border-radius);
|
||||
overflow-x: auto;
|
||||
font-family: var(--font-mono);
|
||||
font-size: 0.875rem;
|
||||
line-height: 1.5;
|
||||
max-height: 300px;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.config-preview code {
|
||||
font-family: var(--font-mono);
|
||||
}
|
||||
|
||||
.config-actions {
|
||||
display: flex;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.config-actions .btn {
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
/* ============================================================================
|
||||
Status & Connection
|
||||
============================================================================ */
|
||||
|
||||
.status-text {
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.connection-status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 999px;
|
||||
background: var(--color-bg-alt);
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.connection-status.connected {
|
||||
color: var(--color-success);
|
||||
}
|
||||
|
||||
.connection-status.disconnected {
|
||||
color: var(--color-danger);
|
||||
}
|
||||
|
||||
/* ============================================================================
|
||||
Loading Overlay
|
||||
============================================================================ */
|
||||
|
||||
.loading-overlay {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background: rgba(0, 0, 0, 0.5);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
z-index: 1000;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border: 4px solid rgba(255, 255, 255, 0.3);
|
||||
border-top: 4px solid white;
|
||||
border-radius: 50%;
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
to {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
|
||||
.loading-overlay p {
|
||||
color: white;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
/* ============================================================================
|
||||
Responsive Design
|
||||
============================================================================ */
|
||||
|
||||
@media (max-width: 1024px) {
|
||||
.main-layout {
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.sidebar {
|
||||
width: 100%;
|
||||
max-height: 300px;
|
||||
flex-direction: row;
|
||||
padding: 1rem;
|
||||
overflow-x: auto;
|
||||
border-right: none;
|
||||
border-bottom: 1px solid var(--color-border);
|
||||
}
|
||||
|
||||
.sidebar-section {
|
||||
min-width: 250px;
|
||||
}
|
||||
|
||||
.main-content {
|
||||
padding: 1rem;
|
||||
}
|
||||
|
||||
.suggestions-container {
|
||||
grid-template-columns: repeat(auto-fill, minmax(150px, 1fr));
|
||||
}
|
||||
|
||||
.message-bubble {
|
||||
max-width: 85%;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 640px) {
|
||||
.header-content h1 {
|
||||
font-size: 1.5rem;
|
||||
}
|
||||
|
||||
.sidebar {
|
||||
flex-direction: column;
|
||||
overflow-x: visible;
|
||||
max-height: none;
|
||||
border-bottom: 1px solid var(--color-border);
|
||||
}
|
||||
|
||||
.sidebar-section {
|
||||
min-width: unset;
|
||||
}
|
||||
|
||||
.suggestions-container {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
|
||||
.message-bubble {
|
||||
max-width: 95%;
|
||||
}
|
||||
|
||||
.btn {
|
||||
font-size: 0.875rem;
|
||||
padding: 0.5rem;
|
||||
}
|
||||
}
|
||||
|
||||
/* ============================================================================
|
||||
Scrollbar Styling
|
||||
============================================================================ */
|
||||
|
||||
::-webkit-scrollbar {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-track {
|
||||
background: var(--color-bg-alt);
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-thumb {
|
||||
background: var(--color-border);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-thumb:hover {
|
||||
background: var(--color-text-light);
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user