chore: add typedialog-ai AI with RAG adn knowledge graph

This commit is contained in:
Jesús Pérez 2025-12-24 03:21:01 +00:00
parent 34508cddf4
commit 01980c9b8d
Signed by: jesus
GPG Key ID: 9F243E355E0BC939
35 changed files with 7957 additions and 0 deletions

View File

@ -0,0 +1,62 @@
[package]
name = "typedialog-ai"
version.workspace = true
edition.workspace = true
authors.workspace = true
repository.workspace = true
license.workspace = true
description = "AI-powered configuration assistant backend and microservice for TypeDialog"
[dependencies]
# Internal
typedialog-core = { path = "../typedialog-core", features = ["ai_backend"] }
# Workspace dependencies (shared with other crates)
tokio = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }
surrealdb = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
serde_yaml = { workspace = true }
toml = { workspace = true }
clap = { workspace = true }
dialoguer = { workspace = true }
colored = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
dirs = { workspace = true }
# Web and HTTP dependencies (now aligned with workspace versions)
# Code migrated to support workspace versions:
# - axum: Upgraded from 0.7 to 0.8.8 (WebSocket Message::Text now uses Utf8Bytes)
# - reqwest: Using workspace 0.12 (streaming API compatible)
# - tower/tower-http: Aligned with axum 0.8.8
reqwest = { workspace = true, features = ["json", "stream"] }
axum = { workspace = true, features = ["ws"] }
tower = { workspace = true }
tower-http = { workspace = true, features = ["cors", "trace"] }
[features]
default = ["openai"]
openai = []
anthropic = []
ollama = []
all-providers = ["openai", "anthropic", "ollama"]
[lib]
name = "typedialog_ai"
path = "src/lib.rs"
[[bin]]
name = "typedialog-ai"
path = "src/main.rs"
[package.metadata.binstall]
pkg-url = "{ repo }/releases/download/v{ version }/typedialog-{ target }.tar.gz"
bin-dir = "bin/{ bin }"
pkg-fmt = "tgz"

View File

@ -0,0 +1,131 @@
//! API error handling
use super::types::ErrorResponse;
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
/// API error type
#[derive(Debug)]
pub enum ApiError {
/// Conversation not found
ConversationNotFound(String),
/// Invalid request
InvalidRequest(String),
/// Database operation failed
DatabaseError(String),
/// LLM operation failed
LlmError(String),
/// Schema analysis failed
SchemaError(String),
/// Configuration generation failed
GenerationError(String),
/// Internal server error
InternalError(String),
}
impl ApiError {
/// Get error code for client handling
fn code(&self) -> &'static str {
match self {
ApiError::ConversationNotFound(_) => "NOT_FOUND",
ApiError::InvalidRequest(_) => "INVALID_REQUEST",
ApiError::DatabaseError(_) => "DATABASE_ERROR",
ApiError::LlmError(_) => "LLM_ERROR",
ApiError::SchemaError(_) => "SCHEMA_ERROR",
ApiError::GenerationError(_) => "GENERATION_ERROR",
ApiError::InternalError(_) => "INTERNAL_ERROR",
}
}
/// Get HTTP status code
fn status(&self) -> StatusCode {
match self {
ApiError::ConversationNotFound(_) => StatusCode::NOT_FOUND,
ApiError::InvalidRequest(_) => StatusCode::BAD_REQUEST,
ApiError::DatabaseError(_) => StatusCode::INTERNAL_SERVER_ERROR,
ApiError::LlmError(_) => StatusCode::INTERNAL_SERVER_ERROR,
ApiError::SchemaError(_) => StatusCode::BAD_REQUEST,
ApiError::GenerationError(_) => StatusCode::INTERNAL_SERVER_ERROR,
ApiError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR,
}
}
/// Get error message
fn message(&self) -> String {
match self {
ApiError::ConversationNotFound(id) => format!("Conversation '{}' not found", id),
ApiError::InvalidRequest(msg) => msg.clone(),
ApiError::DatabaseError(msg) => format!("Database error: {}", msg),
ApiError::LlmError(msg) => format!("LLM error: {}", msg),
ApiError::SchemaError(msg) => format!("Schema error: {}", msg),
ApiError::GenerationError(msg) => format!("Generation error: {}", msg),
ApiError::InternalError(msg) => format!("Internal error: {}", msg),
}
}
}
impl IntoResponse for ApiError {
fn into_response(self) -> Response {
let status = self.status();
let error_response = ErrorResponse::new(self.message(), self.code());
(status, Json(error_response)).into_response()
}
}
impl From<anyhow::Error> for ApiError {
fn from(err: anyhow::Error) -> Self {
ApiError::InternalError(err.to_string())
}
}
impl From<String> for ApiError {
fn from(err: String) -> Self {
ApiError::InternalError(err)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_conversation_not_found_code() {
let err = ApiError::ConversationNotFound("123".to_string());
assert_eq!(err.code(), "NOT_FOUND");
}
#[test]
fn test_conversation_not_found_status() {
let err = ApiError::ConversationNotFound("123".to_string());
assert_eq!(err.status(), StatusCode::NOT_FOUND);
}
#[test]
fn test_invalid_request_code() {
let err = ApiError::InvalidRequest("bad data".to_string());
assert_eq!(err.code(), "INVALID_REQUEST");
}
#[test]
fn test_llm_error_code() {
let err = ApiError::LlmError("api failed".to_string());
assert_eq!(err.code(), "LLM_ERROR");
assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[test]
fn test_error_message() {
let err = ApiError::ConversationNotFound("conv-123".to_string());
assert!(err.message().contains("conv-123"));
}
}

View File

@ -0,0 +1,15 @@
//! REST API module
//!
//! Provides HTTP and WebSocket endpoints for the AI configuration assistant service.
pub mod error;
pub mod rest;
pub mod types;
pub mod websocket;
#[allow(unused_imports)]
pub use error::ApiError;
#[allow(unused_imports)]
pub use rest::{create_router, AppState};
#[allow(unused_imports)]
pub use websocket::{WsMessage, WsResponse};

View File

@ -0,0 +1,311 @@
//! REST API handlers for TypeDialog AI Service
use axum::{
extract::{Path, State},
http::StatusCode,
routing::{delete, get, post},
Json, Router,
};
use chrono::Utc;
use std::sync::{Arc, Mutex};
use super::{error::ApiError, types::*, websocket};
use crate::assistant::ConfigAssistant;
use crate::storage::SurrealDbClient;
/// Shared application state
#[derive(Clone)]
pub struct AppState {
/// Database client
pub db: Arc<SurrealDbClient>,
/// Currently active assistants by conversation ID
pub assistants: Arc<Mutex<std::collections::HashMap<String, ConfigAssistant>>>,
/// Server start time for uptime tracking
pub start_time: std::time::Instant,
}
/// Create Axum router with all routes
pub fn create_router(state: AppState) -> Router {
use crate::web_ui;
Router::new()
// Web UI
.route("/", get(web_ui::index))
.route("/index.html", get(web_ui::index))
.route("/styles.css", get(web_ui::styles))
.route("/app.js", get(web_ui::app))
// Health check
.route("/health", get(health_check))
// Conversation management
.route("/conversations", post(create_conversation))
.route("/conversations/{id}", get(get_conversation))
.route("/conversations/{id}", delete(delete_conversation))
// Message handling
.route("/conversations/{id}/messages", post(send_message))
// Configuration generation
.route("/conversations/{id}/generate", post(generate_config))
// Suggestions
.route(
"/conversations/{id}/suggestions/{field}",
get(get_suggestions),
)
// WebSocket streaming
.route("/ws/{id}", axum::routing::get(websocket::handle_websocket))
.with_state(state)
}
// ============================================================================
// HEALTH CHECK
// ============================================================================
/// Health check endpoint
pub async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
let uptime = state.start_time.elapsed().as_secs();
Json(HealthResponse {
status: "healthy".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
uptime,
})
}
// ============================================================================
// CONVERSATION MANAGEMENT
// ============================================================================
/// Create a new conversation
///
/// # Request
/// ```json
/// {"schema_id": "my-schema"}
/// ```
///
/// # Response
/// ```json
/// {
/// "conversation_id": "uuid",
/// "schema_id": "my-schema",
/// "message": "Conversation started"
/// }
/// ```
pub async fn create_conversation(
State(state): State<AppState>,
Json(req): Json<CreateConversationRequest>,
) -> Result<(StatusCode, Json<CreateConversationResponse>), ApiError> {
tracing::info!(schema_id = %req.schema_id, "Creating new conversation");
// Create conversation in database
let conv_id = state
.db
.create_conversation(&req.schema_id)
.await
.map_err(|e| ApiError::DatabaseError(e.to_string()))?;
Ok((
StatusCode::CREATED,
Json(CreateConversationResponse {
conversation_id: conv_id,
schema_id: req.schema_id,
message: "Conversation created successfully".to_string(),
}),
))
}
/// Get conversation information
pub async fn get_conversation(
State(_state): State<AppState>,
Path(id): Path<String>,
) -> Result<Json<ConversationInfo>, ApiError> {
tracing::debug!(conversation_id = %id, "Getting conversation info");
// TODO: Load conversation from database
Ok(Json(ConversationInfo {
id: id.clone(),
schema_id: "schema".to_string(),
status: "active".to_string(),
created_at: Utc::now().to_rfc3339(),
updated_at: Utc::now().to_rfc3339(),
}))
}
/// Delete a conversation
pub async fn delete_conversation(
State(_state): State<AppState>,
Path(id): Path<String>,
) -> Result<StatusCode, ApiError> {
tracing::info!(conversation_id = %id, "Deleting conversation");
// TODO: Delete conversation from database
Ok(StatusCode::NO_CONTENT)
}
// ============================================================================
// MESSAGE HANDLING
// ============================================================================
/// Send a message to the assistant
///
/// # Request
/// ```json
/// {"message": "What port should I use?"}
/// ```
///
/// # Response
/// ```json
/// {
/// "text": "For a web server, port 8080 is commonly used...",
/// "suggestions": [
/// {
/// "field": "port",
/// "value": "8080",
/// "reasoning": "Common web server port",
/// "confidence": 0.92
/// }
/// ],
/// "message_id": "msg-uuid",
/// "rag_context": "..."
/// }
/// ```
pub async fn send_message(
State(_state): State<AppState>,
Path(id): Path<String>,
Json(req): Json<SendMessageRequest>,
) -> Result<Json<SendMessageResponse>, ApiError> {
tracing::debug!(conversation_id = %id, message_len = req.message.len(), "Processing message");
if req.message.trim().is_empty() {
return Err(ApiError::InvalidRequest(
"Message cannot be empty".to_string(),
));
}
// TODO: Load assistant from state, call send_message(), return response
Err(ApiError::InternalError(
"Assistant not initialized".to_string(),
))
}
// ============================================================================
// CONFIGURATION GENERATION
// ============================================================================
/// Generate configuration from conversation
///
/// # Request
/// ```json
/// {"format": "json"}
/// ```
///
/// # Response
/// ```json
/// {
/// "content": "{\"port\": 8080, ...}",
/// "format": "json",
/// "is_valid": true,
/// "errors": []
/// }
/// ```
pub async fn generate_config(
State(_state): State<AppState>,
Path(id): Path<String>,
Json(req): Json<GenerateConfigRequest>,
) -> Result<Json<GenerateConfigResponse>, ApiError> {
tracing::info!(conversation_id = %id, format = %req.format, "Generating configuration");
// Validate format
if !matches!(req.format.as_str(), "json" | "yaml" | "toml") {
return Err(ApiError::InvalidRequest(format!(
"Invalid format '{}'. Must be 'json', 'yaml', or 'toml'",
req.format
)));
}
// TODO: Load assistant from state, call generate_config(), return response
Err(ApiError::InternalError(
"Assistant not initialized".to_string(),
))
}
// ============================================================================
// SUGGESTIONS
// ============================================================================
/// Get field suggestions
///
/// # Response
/// ```json
/// {
/// "field": "port",
/// "suggestions": [
/// {
/// "field": "port",
/// "value": "8080",
/// "reasoning": "Found in examples",
/// "confidence": 0.85
/// }
/// ]
/// }
/// ```
pub async fn get_suggestions(
State(_state): State<AppState>,
Path((id, field)): Path<(String, String)>,
) -> Result<Json<SuggestionsResponse>, ApiError> {
tracing::debug!(conversation_id = %id, field = %field, "Getting suggestions");
// TODO: Load assistant from state, call suggest_values(), return response
Err(ApiError::InternalError(
"Assistant not initialized".to_string(),
))
}
#[cfg(test)]
mod tests {
use super::*;
async fn create_test_state() -> Result<AppState, Box<dyn std::error::Error>> {
let db = Arc::new(SurrealDbClient::new("memory://", "default", "test").await?);
Ok(AppState {
db,
assistants: Arc::new(Mutex::new(std::collections::HashMap::new())),
start_time: std::time::Instant::now(),
})
}
#[tokio::test]
async fn test_router_creation() {
let state = create_test_state().await.unwrap();
let _router = create_router(state);
// Router created successfully
}
#[test]
fn test_create_conversation_request_validation() {
let req = CreateConversationRequest {
schema_id: "test-schema".to_string(),
};
assert!(!req.schema_id.is_empty());
}
#[test]
fn test_generate_config_request_format_validation() {
let valid_formats = vec!["json", "yaml", "toml"];
for format in valid_formats {
let req = GenerateConfigRequest {
format: format.to_string(),
};
assert!(matches!(req.format.as_str(), "json" | "yaml" | "toml"));
}
}
#[test]
fn test_send_message_request_serialization() {
let req = SendMessageRequest {
message: "Test message".to_string(),
};
let json = serde_json::to_string(&req).unwrap();
let deserialized: SendMessageRequest = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.message, "Test message");
}
}

View File

@ -0,0 +1,220 @@
//! API request and response types
use crate::assistant::{AssistantResponse, FieldSuggestion, GeneratedConfig};
use serde::{Deserialize, Serialize};
/// Request to start a new conversation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateConversationRequest {
/// Schema identifier or name
pub schema_id: String,
}
/// Response when conversation is created
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateConversationResponse {
/// Unique conversation identifier
pub conversation_id: String,
/// Schema identifier
pub schema_id: String,
/// Human-readable message
pub message: String,
}
/// Request to send a message to the assistant
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SendMessageRequest {
/// User message text
pub message: String,
}
/// Response from sending a message
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SendMessageResponse {
/// Assistant's text response
pub text: String,
/// Field suggestions extracted from response
pub suggestions: Vec<FieldSuggestion>,
/// Message ID for tracking
pub message_id: String,
/// RAG context used for generation
pub rag_context: Option<String>,
}
impl From<AssistantResponse> for SendMessageResponse {
fn from(resp: AssistantResponse) -> Self {
SendMessageResponse {
text: resp.text,
suggestions: resp.suggestions,
message_id: resp.message_id,
rag_context: resp.rag_context,
}
}
}
/// Request to get field suggestions
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SuggestionsRequest {
/// Field name to get suggestions for
pub field: String,
}
/// Response with field suggestions
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SuggestionsResponse {
/// Field name
pub field: String,
/// List of suggestions
pub suggestions: Vec<FieldSuggestion>,
}
/// Request to generate final configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateConfigRequest {
/// Output format: "json", "yaml", or "toml"
pub format: String,
}
/// Response with generated configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateConfigResponse {
/// Generated configuration content
pub content: String,
/// Configuration format
pub format: String,
/// Whether configuration is valid
pub is_valid: bool,
/// Validation errors if any
pub errors: Vec<String>,
}
impl From<GeneratedConfig> for GenerateConfigResponse {
fn from(config: GeneratedConfig) -> Self {
GenerateConfigResponse {
content: config.content,
format: config.format,
is_valid: config.is_valid,
errors: config.errors,
}
}
}
/// Conversation metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationInfo {
/// Conversation ID
pub id: String,
/// Schema ID
pub schema_id: String,
/// Conversation status
pub status: String,
/// Created timestamp
pub created_at: String,
/// Updated timestamp
pub updated_at: String,
}
/// Health check response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthResponse {
/// Service status
pub status: String,
/// Service version
pub version: String,
/// Uptime in seconds
pub uptime: u64,
}
/// Error response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
/// Error message
pub error: String,
/// Error code (for clients to handle specific errors)
pub code: String,
/// Optional additional details
pub details: Option<String>,
}
impl ErrorResponse {
/// Create a new error response
pub fn new(error: impl Into<String>, code: impl Into<String>) -> Self {
ErrorResponse {
error: error.into(),
code: code.into(),
details: None,
}
}
/// Add details to error response
pub fn with_details(mut self, details: impl Into<String>) -> Self {
self.details = Some(details.into());
self
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_create_conversation_request_serialization() {
let req = CreateConversationRequest {
schema_id: "test-schema".to_string(),
};
let json = serde_json::to_string(&req).unwrap();
let deserialized: CreateConversationRequest = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.schema_id, "test-schema");
}
#[test]
fn test_send_message_request_serialization() {
let req = SendMessageRequest {
message: "What port should I use?".to_string(),
};
let json = serde_json::to_string(&req).unwrap();
let deserialized: SendMessageRequest = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.message, "What port should I use?");
}
#[test]
fn test_error_response_creation() {
let error = ErrorResponse::new("Invalid request", "INVALID_REQUEST");
assert_eq!(error.error, "Invalid request");
assert_eq!(error.code, "INVALID_REQUEST");
}
#[test]
fn test_error_response_with_details() {
let error = ErrorResponse::new("Invalid request", "INVALID_REQUEST")
.with_details("Field 'message' is required");
assert!(error.details.is_some());
}
#[test]
fn test_generate_config_request_serialization() {
let req = GenerateConfigRequest {
format: "json".to_string(),
};
let json = serde_json::to_string(&req).unwrap();
let deserialized: GenerateConfigRequest = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.format, "json");
}
}

View File

@ -0,0 +1,333 @@
//! WebSocket streaming for real-time LLM responses
//!
//! Provides token-by-token streaming of assistant responses via WebSocket.
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Path, State,
},
response::IntoResponse,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use super::rest::AppState;
use crate::assistant::ConfigAssistant;
/// WebSocket message from client
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsMessage {
/// Message type: "message", "suggestion", "generate", etc.
pub r#type: String,
/// Message content (user message or other data)
pub content: String,
/// Additional data (e.g., format for config generation)
pub data: Option<serde_json::Value>,
}
/// WebSocket response sent to client
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsResponse {
/// Response type: "start", "chunk", "end", "error", "suggestion", etc.
pub r#type: String,
/// Response content
pub content: String,
/// Optional metadata (e.g., message_id, confidence)
pub metadata: Option<serde_json::Value>,
}
impl WsResponse {
/// Create a streaming chunk response
pub fn chunk(content: impl Into<String>) -> Self {
WsResponse {
r#type: "chunk".to_string(),
content: content.into(),
metadata: None,
}
}
/// Create a start response
pub fn start() -> Self {
WsResponse {
r#type: "start".to_string(),
content: String::new(),
metadata: None,
}
}
/// Create an end response with metadata
pub fn end(metadata: Option<serde_json::Value>) -> Self {
WsResponse {
r#type: "end".to_string(),
content: String::new(),
metadata,
}
}
/// Create an error response
pub fn error(error: impl Into<String>) -> Self {
WsResponse {
r#type: "error".to_string(),
content: error.into(),
metadata: None,
}
}
/// Serialize to JSON
pub fn to_json(&self) -> String {
serde_json::to_string(self).unwrap_or_else(|_| {
serde_json::to_string(&WsResponse::error("Serialization failed")).unwrap()
})
}
}
/// Handle WebSocket upgrade and streaming
///
/// Accepts a WebSocket connection and streams LLM responses in real-time.
/// The client should send JSON messages with:
/// - type: "message", "generate", or "suggestions"
/// - content: the user input
/// - data: optional additional data (e.g., format for generation)
pub async fn handle_websocket(
ws: WebSocketUpgrade,
Path(conversation_id): Path<String>,
State(state): State<AppState>,
) -> impl IntoResponse {
ws.on_upgrade(|socket| handle_socket(socket, conversation_id, state))
}
async fn handle_socket(mut socket: WebSocket, conversation_id: String, state: AppState) {
// Get or create assistant for this conversation
let mut assistant_opt = None;
while let Some(msg_result) = socket.recv().await {
match msg_result {
Ok(Message::Text(text)) => {
// Parse incoming message
match serde_json::from_str::<WsMessage>(&text) {
Ok(ws_msg) => {
// Handle message based on type
match ws_msg.r#type.as_str() {
"message" => {
handle_message_type(
&mut socket,
&mut assistant_opt,
&conversation_id,
&state,
&ws_msg,
)
.await;
}
"generate" => {
handle_generate_type(
&mut socket,
&mut assistant_opt,
&conversation_id,
&state,
&ws_msg,
)
.await;
}
"suggestions" => {
handle_suggestions_type(
&mut socket,
&mut assistant_opt,
&conversation_id,
&state,
&ws_msg,
)
.await;
}
_ => {
let error_resp = WsResponse::error(format!(
"Unknown message type: {}",
ws_msg.r#type
));
drop(
socket
.send(Message::Text(error_resp.to_json().into()))
.await,
);
}
}
}
Err(e) => {
let error_resp =
WsResponse::error(format!("Failed to parse message: {}", e));
drop(
socket
.send(Message::Text(error_resp.to_json().into()))
.await,
);
}
}
}
Ok(Message::Close(_)) => {
tracing::debug!(conv_id = %conversation_id, "WebSocket closed by client");
break;
}
Ok(_) => {
// Ignore other message types
}
Err(e) => {
tracing::error!(conv_id = %conversation_id, error = %e, "WebSocket error");
break;
}
}
}
}
async fn handle_message_type(
socket: &mut WebSocket,
_assistant_opt: &mut Option<ConfigAssistant>,
_conversation_id: &str,
_state: &AppState,
ws_msg: &WsMessage,
) {
// Send start marker
let start_resp = WsResponse::start();
if socket
.send(Message::Text(start_resp.to_json().into()))
.await
.is_err()
{
return;
}
// Get or create assistant (for now, just send the message back with chunks)
// In a full implementation, this would get the assistant from state
// and call stream_message() on it
// Simulate streaming response by splitting the content into words
let words: Vec<&str> = ws_msg.content.split_whitespace().collect();
for (i, word) in words.iter().enumerate() {
let chunk_resp = WsResponse::chunk(if i == 0 {
word.to_string()
} else {
format!(" {}", word)
});
if socket
.send(Message::Text(chunk_resp.to_json().into()))
.await
.is_err()
{
return;
}
// Small delay to simulate token streaming
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
}
// Send end marker
let end_resp = WsResponse::end(Some(json!({
"type": "message",
"status": "complete"
})));
drop(socket.send(Message::Text(end_resp.to_json().into())).await);
}
async fn handle_generate_type(
socket: &mut WebSocket,
_assistant_opt: &mut Option<ConfigAssistant>,
_conversation_id: &str,
_state: &AppState,
_ws_msg: &WsMessage,
) {
// Send generation response
let resp = WsResponse::error("Configuration generation not yet implemented via WebSocket");
drop(socket.send(Message::Text(resp.to_json().into())).await);
}
async fn handle_suggestions_type(
socket: &mut WebSocket,
_assistant_opt: &mut Option<ConfigAssistant>,
_conversation_id: &str,
_state: &AppState,
ws_msg: &WsMessage,
) {
// Get field name from content
let field_name = &ws_msg.content;
// Send suggestions response with mock data
let suggestions = json!([
{
"field": field_name,
"value": "example_value",
"reasoning": "Found in examples",
"confidence": 0.85
}
]);
let resp = WsResponse {
r#type: "suggestions".to_string(),
content: serde_json::to_string(&suggestions).unwrap_or_default(),
metadata: Some(json!({"field": field_name})),
};
drop(socket.send(Message::Text(resp.to_json().into())).await);
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ws_message_serialization() {
let msg = WsMessage {
r#type: "message".to_string(),
content: "What port?".to_string(),
data: None,
};
let json = serde_json::to_string(&msg).unwrap();
let deserialized: WsMessage = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.r#type, "message");
}
#[test]
fn test_ws_response_chunk() {
let resp = WsResponse::chunk("port");
assert_eq!(resp.r#type, "chunk");
assert_eq!(resp.content, "port");
}
#[test]
fn test_ws_response_start() {
let resp = WsResponse::start();
assert_eq!(resp.r#type, "start");
}
#[test]
fn test_ws_response_end() {
let resp = WsResponse::end(None);
assert_eq!(resp.r#type, "end");
}
#[test]
fn test_ws_response_error() {
let resp = WsResponse::error("Something failed");
assert_eq!(resp.r#type, "error");
assert!(resp.content.contains("failed"));
}
#[test]
fn test_ws_response_to_json() {
let resp = WsResponse::chunk("test");
let json = resp.to_json();
assert!(json.contains("chunk"));
assert!(json.contains("test"));
}
#[test]
fn test_ws_response_with_metadata() {
let metadata = Some(serde_json::json!({"message_id": "123"}));
let resp = WsResponse::end(metadata);
assert!(resp.metadata.is_some());
}
}

View File

@ -0,0 +1,664 @@
#![allow(dead_code)]
//! AI Configuration Assistant Engine
//!
//! Core conversation loop integrating:
//! - LLM provider (OpenAI, Claude, Ollama)
//! - RAG system for example retrieval
//! - Nickel schema understanding
//! - Configuration generation and validation
//! - Persistent conversation storage
use crate::llm::{
prompts::system::config_assistant_system, rag_integration::format_rag_context,
GenerationOptions, LlmProvider, Message as LlmMessage,
};
use crate::storage::{MessageRole, SurrealDbClient};
use anyhow::Result;
use chrono::Utc;
use futures::stream::StreamExt;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use typedialog_core::ai::rag::RagSystem;
use super::schema_analyzer::AnalyzedSchema;
/// Structured response from the assistant
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssistantResponse {
/// The assistant's text response
pub text: String,
/// Field suggestions extracted from the response
pub suggestions: Vec<FieldSuggestion>,
/// RAG results that informed this response
pub rag_context: Option<String>,
/// Message ID in storage
pub message_id: String,
}
/// Suggestion for a configuration field
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldSuggestion {
/// Field name (e.g., "port", "database_url")
pub field: String,
/// Suggested value
pub value: String,
/// Explanation of the suggestion
pub reasoning: String,
/// Confidence score (0-1)
pub confidence: f32,
}
/// Configuration generation result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedConfig {
/// Generated configuration (JSON, YAML, or TOML)
pub content: String,
/// Format of the configuration
pub format: String,
/// Whether the configuration is valid
pub is_valid: bool,
/// Validation errors if any
pub errors: Vec<String>,
}
/// AI Configuration Assistant Engine
///
/// Manages multi-turn conversations for intelligent configuration generation
/// using RAG-powered suggestions and LLM generation.
pub struct ConfigAssistant {
/// LLM provider for text generation
llm: Arc<dyn LlmProvider>,
/// RAG system for retrieving relevant examples
rag: RagSystem,
/// Analyzed schema for the current conversation
schema: AnalyzedSchema,
/// Database client for persistence
storage: Arc<SurrealDbClient>,
/// Current conversation ID
conversation_id: String,
/// Conversation history for context
history: Vec<ConversationTurn>,
}
/// Single turn in a conversation
#[derive(Debug, Clone)]
pub struct ConversationTurn {
/// User or assistant message
role: MessageRole,
/// Message content
content: String,
/// Timestamp
timestamp: chrono::DateTime<Utc>,
}
impl ConfigAssistant {
/// Create a new configuration assistant for a conversation
///
/// # Arguments
///
/// * `llm` - LLM provider (OpenAI, Claude, Ollama)
/// * `rag` - Initialized RAG system
/// * `schema` - Analyzed schema for the conversation
/// * `storage` - Database client for persistence
/// * `conversation_id` - Unique conversation identifier
///
/// # Returns
///
/// Configured assistant ready for use
pub fn new(
llm: Arc<dyn LlmProvider>,
rag: RagSystem,
schema: AnalyzedSchema,
storage: Arc<SurrealDbClient>,
conversation_id: String,
) -> Self {
ConfigAssistant {
llm,
rag,
schema,
storage,
conversation_id,
history: Vec::new(),
}
}
/// Send a user message and get assistant response
///
/// The complete flow:
/// 1. Store user message
/// 2. Retrieve relevant examples via RAG
/// 3. Format prompt with schema and RAG context
/// 4. Generate response with LLM
/// 5. Extract field suggestions
/// 6. Store assistant response
/// 7. Return structured response
pub async fn send_message(&mut self, user_message: &str) -> Result<AssistantResponse> {
tracing::debug!(conv_id = %self.conversation_id, "Processing user message");
// 1. Store user message
let user_msg_id = self
.storage
.create_message(&self.conversation_id, MessageRole::User, user_message)
.await?;
self.history.push(ConversationTurn {
role: MessageRole::User,
content: user_message.to_string(),
timestamp: Utc::now(),
});
// 2. Retrieve relevant examples via RAG
let rag_results = self.rag.retrieve(user_message)?;
// Convert to llm::rag_integration::RetrievalResult for formatting
let llm_rag_results: Vec<_> = rag_results
.iter()
.map(|r| crate::llm::rag_integration::RetrievalResult {
doc_id: r.doc_id.clone(),
content: r.content.clone(),
combined_score: r.combined_score,
})
.collect();
let rag_context = format_rag_context(&llm_rag_results);
tracing::debug!(results = rag_results.len(), "RAG retrieval complete");
// 3. Format prompt with schema and RAG context
let system_prompt = config_assistant_system();
let schema_description = self.schema.format_for_prompt();
let context_section = if !rag_context.is_empty() {
format!("## Relevant Examples\n\n{}\n\n", rag_context)
} else {
String::new()
};
let user_prompt = format!(
"{}{}## Current Question\n\n{}",
schema_description, context_section, user_message
);
let mut messages = vec![LlmMessage::system(&system_prompt)];
// Add conversation history for context
for turn in &self.history[..self.history.len().saturating_sub(1)] {
match turn.role {
MessageRole::User => {
messages.push(LlmMessage::user(&turn.content));
}
MessageRole::Assistant => {
messages.push(LlmMessage::assistant(&turn.content));
}
}
}
messages.push(LlmMessage::user(&user_prompt));
// 4. Generate response with LLM
let options = GenerationOptions::analytical();
let response_text = self.llm.generate(&messages, &options).await?;
tracing::debug!("LLM generation complete");
// 5. Extract field suggestions
let suggestions = self.extract_suggestions(&response_text)?;
// 6. Store assistant response
let assistant_msg_id = self
.storage
.create_message(
&self.conversation_id,
MessageRole::Assistant,
&response_text,
)
.await?;
self.history.push(ConversationTurn {
role: MessageRole::Assistant,
content: response_text.clone(),
timestamp: Utc::now(),
});
// Link RAG results to user message for debugging/analytics
let rag_ids: Vec<String> = rag_results.iter().map(|r| r.doc_id.clone()).collect();
if !rag_ids.is_empty() {
self.storage.link_rag_results(&user_msg_id, rag_ids).await?;
}
Ok(AssistantResponse {
text: response_text,
suggestions,
rag_context: Some(rag_context),
message_id: assistant_msg_id,
})
}
/// Get suggestions for a specific field
///
/// Uses RAG to find relevant examples and extract suggested values
pub async fn suggest_values(&mut self, field_name: &str) -> Result<Vec<FieldSuggestion>> {
if self.schema.find_field(field_name).is_none() {
return Ok(vec![]);
}
// Query RAG for this field
let query = format!(
"examples for {} configuration with {} field",
self.schema.name, field_name
);
let rag_results = self.rag.retrieve(&query)?;
// Extract values from RAG results
let mut suggestions = Vec::new();
for result in rag_results.iter().take(3) {
// Extract values from the content
if let Some(value) = self.extract_field_value(&result.content, field_name) {
suggestions.push(FieldSuggestion {
field: field_name.to_string(),
value,
reasoning: format!("Found in example: {}", result.doc_id),
confidence: result.combined_score,
});
}
}
Ok(suggestions)
}
/// Generate final configuration from conversation
///
/// Creates a complete configuration based on the collected fields
/// and validates it against the schema
pub async fn generate_config(&mut self, format: &str) -> Result<GeneratedConfig> {
tracing::debug!(format = %format, "Generating configuration");
// Collect all values from conversation
let collected_values = self.collect_field_values()?;
// Build generation prompt
let generation_prompt = format!(
"Based on our conversation, generate a {} configuration for {}.\n\
Collected values: {}\n\
Schema: {}\n\
Ensure all required fields are included and all values match the schema constraints.",
format,
self.schema.name,
serde_json::to_string(&collected_values)?,
self.schema.format_for_prompt()
);
#[allow(clippy::useless_vec)]
let messages = vec![
LlmMessage::system(config_assistant_system()),
LlmMessage::user(&generation_prompt),
];
// Generate configuration
let options = GenerationOptions::config_generation();
let config_content = self.llm.generate(&messages, &options).await?;
tracing::debug!("Configuration generated, validating");
// TODO: Validate configuration with Nickel
// For now, mark as valid if it's not empty
let is_valid = !config_content.trim().is_empty();
Ok(GeneratedConfig {
content: config_content,
format: format.to_string(),
is_valid,
errors: vec![],
})
}
/// Get conversation history
pub fn history(&self) -> &[ConversationTurn] {
&self.history
}
/// Get conversation ID
pub fn conversation_id(&self) -> &str {
&self.conversation_id
}
/// Get schema
pub fn schema(&self) -> &AnalyzedSchema {
&self.schema
}
/// Stream a response for real-time WebSocket delivery
///
/// Similar to send_message but returns a stream instead of full response
/// for token-by-token delivery via WebSocket
pub async fn stream_message(
&mut self,
user_message: &str,
) -> Result<futures::stream::BoxStream<'static, String>> {
tracing::debug!(conv_id = %self.conversation_id, "Streaming user message");
// Store user message
let _user_msg_id = self
.storage
.create_message(&self.conversation_id, MessageRole::User, user_message)
.await?;
self.history.push(ConversationTurn {
role: MessageRole::User,
content: user_message.to_string(),
timestamp: Utc::now(),
});
// Retrieve relevant examples via RAG
let rag_results = self.rag.retrieve(user_message)?;
// Convert to llm::rag_integration::RetrievalResult for formatting
let llm_rag_results: Vec<_> = rag_results
.iter()
.map(|r| crate::llm::rag_integration::RetrievalResult {
doc_id: r.doc_id.clone(),
content: r.content.clone(),
combined_score: r.combined_score,
})
.collect();
let rag_context = format_rag_context(&llm_rag_results);
// Format prompt with schema and RAG context
let system_prompt = config_assistant_system();
let schema_description = self.schema.format_for_prompt();
let context_section = if !rag_context.is_empty() {
format!("## Relevant Examples\n\n{}\n\n", rag_context)
} else {
String::new()
};
let user_prompt = format!(
"{}{}## Current Question\n\n{}",
schema_description, context_section, user_message
);
let mut messages = vec![LlmMessage::system(&system_prompt)];
// Add conversation history for context
for turn in &self.history[..self.history.len().saturating_sub(1)] {
match turn.role {
MessageRole::User => {
messages.push(LlmMessage::user(&turn.content));
}
MessageRole::Assistant => {
messages.push(LlmMessage::assistant(&turn.content));
}
}
}
messages.push(LlmMessage::user(&user_prompt));
// Create stream
let options = GenerationOptions::analytical();
let stream = self.llm.stream(&messages, &options).await?;
// Map stream to handle Result types - take only Ok values
let mapped_stream = stream
.filter_map(|result| async move {
match result {
Ok(chunk) => Some(chunk),
Err(e) => {
tracing::error!("Stream error: {:?}", e);
None
}
}
})
.boxed();
Ok(mapped_stream)
}
// ============================================================================
// PRIVATE HELPER METHODS
// ============================================================================
/// Format conversation history for inclusion in prompts
fn format_conversation_history(&self) -> String {
if self.history.len() <= 1 {
return String::new();
}
let mut formatted = String::from("## Conversation History\n\n");
for turn in &self.history[..self.history.len().saturating_sub(1)] {
match turn.role {
MessageRole::User => {
formatted.push_str(&format!("**User:** {}\n\n", turn.content));
}
MessageRole::Assistant => {
formatted.push_str(&format!("**Assistant:** {}\n\n", turn.content));
}
}
}
formatted
}
/// Extract field suggestions from LLM response
fn extract_suggestions(&self, response: &str) -> Result<Vec<FieldSuggestion>> {
let mut suggestions = Vec::new();
// Simple pattern matching for field suggestions
// Looks for patterns like "field_name: value" or "field_name should be value"
for field in self.schema.fields.iter() {
if let Some(pos) = response.find(&field.flat_name) {
// Look for value near the field name
let context = &response[pos..std::cmp::min(pos + 200, response.len())];
if let Some(value) = self.extract_value_from_context(context, &field.flat_name) {
suggestions.push(FieldSuggestion {
field: field.flat_name.clone(),
value,
reasoning: "Extracted from LLM response".to_string(),
confidence: 0.7,
});
}
}
}
Ok(suggestions)
}
/// Extract field value from RAG result content
fn extract_field_value(&self, content: &str, field_name: &str) -> Option<String> {
// Look for patterns like: field_name = "value" or field_name: value
let patterns = vec![
format!("{} = \"", field_name),
format!("{} = '", field_name),
format!("{} = ", field_name),
format!("{}: ", field_name),
];
for pattern in patterns {
if let Some(pos) = content.find(&pattern) {
let start = pos + pattern.len();
let remainder = &content[start..];
// Extract value until quote or line end
if pattern.contains("\"") {
if let Some(end) = remainder.find('"') {
return Some(remainder[..end].to_string());
}
} else if pattern.contains("'") {
if let Some(end) = remainder.find('\'') {
return Some(remainder[..end].to_string());
}
} else {
// Find next whitespace or comma
let end = remainder
.find(|c: char| c.is_whitespace() || c == ',' || c == ';')
.unwrap_or(remainder.len());
let value = remainder[..end].trim().to_string();
if !value.is_empty() {
return Some(value);
}
}
}
}
None
}
/// Extract value from response context
fn extract_value_from_context(&self, context: &str, _field_name: &str) -> Option<String> {
// Simple extraction: look for colon followed by value
if let Some(colon_pos) = context.find(':') {
let after_colon = &context[colon_pos + 1..];
let value = after_colon
.split_whitespace()
.next()
.unwrap_or("")
.trim_matches(|c: char| !c.is_alphanumeric() && c != '_' && c != '-' && c != '.');
if !value.is_empty() {
return Some(value.to_string());
}
}
None
}
/// Collect field values from conversation history
fn collect_field_values(&self) -> Result<serde_json::Value> {
let mut values = serde_json::json!({});
// Extract values from all assistant responses
for turn in &self.history {
if turn.role == MessageRole::Assistant {
for suggestion in self.extract_suggestions(&turn.content)? {
values[suggestion.field] = serde_json::json!(suggestion.value);
}
}
}
Ok(values)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::assistant::schema_analyzer::SchemaAnalyzer;
use typedialog_core::nickel::schema_ir::{NickelFieldIR, NickelSchemaIR, NickelType};
fn create_test_schema() -> AnalyzedSchema {
let fields = vec![NickelFieldIR {
path: vec!["port".to_string()],
flat_name: "port".to_string(),
alias: Some("Port".to_string()),
nickel_type: NickelType::Number,
doc: Some("Server port".to_string()),
default: None,
optional: false,
contract: None,
contract_call: None,
group: None,
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
}];
let schema_ir = NickelSchemaIR {
name: "ServerConfig".to_string(),
description: Some("Server configuration".to_string()),
fields,
};
SchemaAnalyzer::analyze_schema_ir(schema_ir).unwrap()
}
#[test]
fn test_assistant_response_serialization() {
let response = AssistantResponse {
text: "Configuration looks good".to_string(),
suggestions: vec![FieldSuggestion {
field: "port".to_string(),
value: "8080".to_string(),
reasoning: "Standard web port".to_string(),
confidence: 0.9,
}],
rag_context: Some("Example context".to_string()),
message_id: "msg-123".to_string(),
};
let json = serde_json::to_string(&response).unwrap();
drop(serde_json::from_str::<AssistantResponse>(&json).unwrap());
}
#[test]
fn test_generated_config_structure() {
let config = GeneratedConfig {
content: "port = 8080".to_string(),
format: "toml".to_string(),
is_valid: true,
errors: vec![],
};
assert!(config.is_valid);
assert_eq!(config.format, "toml");
}
#[test]
fn test_field_suggestion_serialization() {
let suggestion = FieldSuggestion {
field: "database".to_string(),
value: "localhost".to_string(),
reasoning: "Default development server".to_string(),
confidence: 0.85,
};
let json = serde_json::to_string(&suggestion).unwrap();
let deserialized: FieldSuggestion = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.field, "database");
}
#[test]
fn test_extract_field_value_equals_quote() {
let content = r#"port = "8080""#;
// Verify content parsing
assert!(content.contains("8080"));
}
#[test]
fn test_conversation_turn_storage() {
let turn = ConversationTurn {
role: MessageRole::User,
content: "What port should I use?".to_string(),
timestamp: Utc::now(),
};
assert_eq!(turn.role, MessageRole::User);
assert!(turn.content.contains("port"));
}
#[test]
fn test_message_role_equality() {
assert_eq!(MessageRole::User, MessageRole::User);
assert_ne!(MessageRole::User, MessageRole::Assistant);
}
}

View File

@ -0,0 +1,211 @@
//! RAG Knowledge Base Indexer
//!
//! Indexes Nickel schemas and configuration examples into the RAG system
//! for retrieval-augmented generation. Supports both in-memory and persistent
//! indexing for efficient startup and search.
use anyhow::Result;
use std::path::PathBuf;
use typedialog_core::ai::rag::RagSystem;
/// Build knowledge base from available schemas and examples
///
/// Indexes all Nickel schema files and TOML configuration examples
/// found in the project's examples directories.
pub async fn build_knowledge_base(rag: &mut RagSystem) -> Result<()> {
tracing::info!("Building RAG knowledge base from project examples");
let mut documents = Vec::new();
let _doc_count = documents.len();
// Index Nickel schemas - looking in common example directories
let schema_dirs = vec![
"examples/07-nickel-generation",
"examples/nickel",
"examples/schemas",
];
for dir in schema_dirs {
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().map(|ext| ext == "ncl").unwrap_or(false) {
if let Ok(content) = std::fs::read_to_string(&path) {
let doc_id = path.display().to_string();
tracing::debug!(path = %doc_id, "Indexing Nickel schema");
documents.push((doc_id, content));
}
}
}
}
}
// Index TOML configuration examples
let form_dirs = vec!["examples", "examples/forms", "examples/configs"];
for dir in form_dirs {
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().map(|ext| ext == "toml").unwrap_or(false) {
if let Ok(content) = std::fs::read_to_string(&path) {
let doc_id = path.display().to_string();
tracing::debug!(path = %doc_id, "Indexing TOML configuration");
documents.push((doc_id, content));
}
}
}
}
}
if documents.is_empty() {
tracing::warn!("No documents found to index");
return Ok(());
}
tracing::info!(count = documents.len(), "Adding documents to RAG index");
// Batch add documents for efficiency
rag.add_documents_batch(documents)?;
// Save to persistent storage for faster startup
let cache_path = ".typedialog/knowledge-base.bin";
if let Some(parent) = PathBuf::from(cache_path).parent() {
drop(std::fs::create_dir_all(parent));
}
rag.save_to_file(cache_path)?;
tracing::info!(path = cache_path, "RAG knowledge base cached");
Ok(())
}
/// Load knowledge base from cache if available
///
/// Attempts to load a previously indexed knowledge base from disk.
/// Falls back to building a new index if cache is unavailable.
pub async fn load_or_build_knowledge_base(rag: &mut RagSystem) -> Result<()> {
let cache_path = ".typedialog/knowledge-base.bin";
// Try to load from cache
if PathBuf::from(cache_path).exists() {
match RagSystem::load_from_file(cache_path) {
Ok(cached_rag) => {
tracing::info!(path = cache_path, "Loaded RAG knowledge base from cache");
*rag = cached_rag;
return Ok(());
}
Err(e) => {
tracing::warn!(error = %e, "Failed to load cached knowledge base, rebuilding");
}
}
}
// Build new index if cache unavailable or corrupted
build_knowledge_base(rag).await
}
/// Get statistics about the indexed knowledge base
///
/// Returns information about what's been indexed including
/// document counts by type.
pub struct KnowledgeBaseStats {
/// Total number of indexed documents
pub total_documents: usize,
/// Number of Nickel schemas
pub schema_count: usize,
/// Number of TOML configurations
pub config_count: usize,
/// Cache file path
pub cache_path: String,
/// Whether cache exists
pub cache_exists: bool,
}
/// Get knowledge base statistics
pub fn get_knowledge_base_stats() -> KnowledgeBaseStats {
let cache_path = ".typedialog/knowledge-base.bin";
// Count documents by scanning directories
let mut schema_count = 0;
let mut config_count = 0;
let schema_dirs = vec![
"examples/07-nickel-generation",
"examples/nickel",
"examples/schemas",
];
for dir in schema_dirs {
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().map(|ext| ext == "ncl").unwrap_or(false) {
schema_count += 1;
}
}
}
}
let form_dirs = vec!["examples", "examples/forms", "examples/configs"];
for dir in form_dirs {
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().map(|ext| ext == "toml").unwrap_or(false) {
config_count += 1;
}
}
}
}
let cache_exists = PathBuf::from(cache_path).exists();
KnowledgeBaseStats {
total_documents: schema_count + config_count,
schema_count,
config_count,
cache_path: cache_path.to_string(),
cache_exists,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_knowledge_base_stats_creation() {
let stats = get_knowledge_base_stats();
assert_eq!(stats.cache_path, ".typedialog/knowledge-base.bin");
}
#[test]
fn test_knowledge_base_stats_total() {
let stats = get_knowledge_base_stats();
assert_eq!(
stats.total_documents,
stats.schema_count + stats.config_count
);
}
#[test]
fn test_cache_path_consistency() {
let stats = get_knowledge_base_stats();
assert!(stats.cache_path.contains("knowledge-base.bin"));
}
#[test]
fn test_stats_has_expected_fields() {
let stats = get_knowledge_base_stats();
let _ = stats.schema_count;
let _ = stats.config_count;
let _ = stats.total_documents;
// Stats created successfully with all fields
}
}

View File

@ -0,0 +1,19 @@
//! AI configuration assistant components
//!
//! Provides the core assistant logic integrating RAG, LLM, and Nickel schema understanding.
pub mod engine;
pub mod indexer;
pub mod schema_analyzer;
pub mod validator;
#[allow(unused_imports)]
pub use engine::{AssistantResponse, ConfigAssistant, FieldSuggestion, GeneratedConfig};
#[allow(unused_imports)]
pub use indexer::{
build_knowledge_base, get_knowledge_base_stats, load_or_build_knowledge_base,
KnowledgeBaseStats,
};
pub use schema_analyzer::SchemaAnalyzer;
#[allow(unused_imports)]
pub use validator::{ConfigValidator, ValidationResult};

View File

@ -0,0 +1,514 @@
//! Schema analysis module for understanding Nickel configuration schemas
//!
//! Wraps the existing Nickel integration (`NickelCli`, `MetadataParser`) to:
//! - Load and parse Nickel schema files
//! - Generate natural language descriptions for LLM prompts
//! - Extract field metadata and constraints
//! - Format schema information for prompt engineering
use anyhow::Result;
use std::path::Path;
use typedialog_core::nickel::cli::NickelCli;
use typedialog_core::nickel::parser::MetadataParser;
use typedialog_core::nickel::schema_ir::{NickelFieldIR, NickelSchemaIR, NickelType};
/// Analyzes Nickel configuration schemas and generates descriptions
#[derive(Debug, Clone)]
pub struct SchemaAnalyzer;
/// Analyzed schema with natural language descriptions
#[derive(Debug, Clone)]
pub struct AnalyzedSchema {
/// Schema name
pub name: String,
/// Natural language description of the entire schema
pub description: String,
/// Individual field analyses
pub fields: Vec<AnalyzedField>,
/// Field categories grouped by semantic meaning
pub categories: Vec<FieldCategory>,
}
/// Analysis of a single schema field
#[derive(Debug, Clone)]
pub struct AnalyzedField {
/// Field path (e.g., ["user", "name"] for user.name)
pub path: Vec<String>,
/// Flattened field name for display
pub flat_name: String,
/// Natural language description of the field
pub description: String,
/// Type description (e.g., "string", "number", "array of strings")
pub type_description: String,
/// Whether the field is required
pub required: bool,
/// Default value as string representation
pub default_value: Option<String>,
/// Constraints and validations
pub constraints: Vec<String>,
/// Field category (grouping)
pub category: Option<String>,
}
/// Grouped collection of fields
#[derive(Debug, Clone)]
pub struct FieldCategory {
/// Category name
pub name: String,
/// Fields in this category
pub field_names: Vec<String>,
/// Category description
pub description: String,
}
impl SchemaAnalyzer {
/// Load and analyze a Nickel schema file
///
/// # Arguments
///
/// * `schema_path` - Path to the .ncl schema file
///
/// # Returns
///
/// Analyzed schema with natural language descriptions
///
/// # Errors
///
/// Returns error if the schema file cannot be read or parsed
pub fn analyze_file(schema_path: &Path) -> Result<AnalyzedSchema> {
// Query schema metadata using existing Nickel CLI
let metadata_json = NickelCli::query(schema_path, None)?;
// Parse JSON into structured schema IR
let schema_ir = MetadataParser::parse(metadata_json)?;
// Generate natural language analysis
Self::analyze_schema_ir(schema_ir)
}
/// Analyze an already-parsed schema IR
pub fn analyze_schema_ir(schema: NickelSchemaIR) -> Result<AnalyzedSchema> {
// Generate schema-level description
let description = Self::describe_schema(&schema);
// Analyze each field
let fields: Vec<AnalyzedField> = schema.fields.iter().map(Self::analyze_field).collect();
// Group fields by category
let categories = Self::group_fields_by_category(&fields);
Ok(AnalyzedSchema {
name: schema.name,
description,
fields,
categories,
})
}
/// Generate a natural language description of the schema
fn describe_schema(schema: &NickelSchemaIR) -> String {
let field_count = schema.fields.len();
let required_count = schema.fields.iter().filter(|f| !f.optional).count();
let mut parts = vec![];
if let Some(doc) = &schema.description {
parts.push(doc.clone());
} else {
parts.push(format!("Configuration schema with {} fields", field_count));
}
parts.push(format!(
"{} required, {} optional",
required_count,
field_count - required_count
));
// List field names for quick reference
let field_names: Vec<String> = schema
.fields
.iter()
.map(|f| {
if f.optional {
format!("{}?", f.flat_name)
} else {
f.flat_name.clone()
}
})
.collect();
if !field_names.is_empty() {
parts.push(format!("Fields: {}", field_names.join(", ")));
}
parts.join(". ")
}
/// Analyze a single field and generate descriptions
fn analyze_field(field: &NickelFieldIR) -> AnalyzedField {
let type_description = Self::describe_type(&field.nickel_type);
let constraints = Self::extract_constraints(field);
let default_value = field.default.as_ref().map(|v| match v {
serde_json::Value::String(s) => format!("\"{}\"", s),
serde_json::Value::Number(n) => n.to_string(),
serde_json::Value::Bool(b) => b.to_string(),
serde_json::Value::Null => "null".to_string(),
serde_json::Value::Array(_) => "[...]".to_string(),
serde_json::Value::Object(_) => "{...}".to_string(),
});
let description = if let Some(doc) = &field.doc {
doc.clone()
} else {
format!(
"{}{}{}",
Self::capitalize_first_letter(&field.flat_name),
if field.optional { " (optional)" } else { "" },
if !constraints.is_empty() {
format!(" with constraints: {}", constraints.join(", "))
} else {
String::new()
}
)
};
AnalyzedField {
path: field.path.clone(),
flat_name: field.flat_name.clone(),
description,
type_description,
required: !field.optional,
default_value,
constraints,
category: field.group.clone(),
}
}
/// Generate a description of a Nickel type
fn describe_type(nickel_type: &NickelType) -> String {
match nickel_type {
NickelType::String => "string".to_string(),
NickelType::Number => "number".to_string(),
NickelType::Bool => "boolean".to_string(),
NickelType::Array(elem_type) => {
format!("array of {}", Self::describe_type(elem_type))
}
NickelType::Record(_) => "object/record".to_string(),
NickelType::Custom(name) => format!("custom type: {}", name),
}
}
/// Extract constraint descriptions from a field
fn extract_constraints(field: &NickelFieldIR) -> Vec<String> {
let mut constraints = vec![];
// Add contract/predicate constraints
if let Some(contract) = &field.contract_call {
constraints.push(format!(
"{}{}",
contract.function,
if let Some(args) = &contract.args {
format!("({})", args)
} else {
String::new()
}
));
}
// Add encryption constraint if present
if let Some(enc) = &field.encryption_metadata {
if enc.sensitive {
let backend = enc
.backend
.as_ref()
.map(|b| format!(" via {}", b))
.unwrap_or_default();
constraints.push(format!("sensitive/encrypted{}", backend));
}
}
// Add array-of-records marker
if field.is_array_of_records {
constraints.push("repeating group".to_string());
}
constraints
}
/// Group fields into categories based on semantic grouping
fn group_fields_by_category(fields: &[AnalyzedField]) -> Vec<FieldCategory> {
use std::collections::HashMap;
let mut groups: HashMap<String, Vec<String>> = HashMap::new();
for field in fields {
let category = field
.category
.clone()
.unwrap_or_else(|| "General".to_string());
groups
.entry(category)
.or_default()
.push(field.flat_name.clone());
}
groups
.into_iter()
.map(|(name, field_names)| FieldCategory {
description: format!("{} configuration", name),
name,
field_names,
})
.collect()
}
/// Capitalize the first letter of a string
fn capitalize_first_letter(s: &str) -> String {
let mut chars = s.chars();
match chars.next() {
None => String::new(),
Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(),
}
}
}
impl AnalyzedSchema {
/// Format the schema as a prompt for an LLM
///
/// Returns a markdown-formatted description suitable for inclusion in LLM prompts
pub fn format_for_prompt(&self) -> String {
let mut lines = vec![];
// Schema description
lines.push(format!("# {}", self.name));
lines.push(String::new());
lines.push(self.description.clone());
lines.push(String::new());
// Field groups by category
if !self.categories.is_empty() {
lines.push("## Configuration Sections".to_string());
lines.push(String::new());
for category in &self.categories {
lines.push(format!("### {}", category.name));
lines.push(String::new());
for field_name in &category.field_names {
if let Some(field) = self.fields.iter().find(|f| &f.flat_name == field_name) {
let required = if field.required {
"**required**"
} else {
"optional"
};
lines.push(format!(
"- **{}** ({}): {} [{}]",
field.flat_name, required, field.description, field.type_description
));
if !field.constraints.is_empty() {
lines
.push(format!(" - Constraints: {}", field.constraints.join(", ")));
}
if let Some(default) = &field.default_value {
lines.push(format!(" - Default: {}", default));
}
}
}
lines.push(String::new());
}
} else {
// No categories, list all fields
lines.push("## Configuration Fields".to_string());
lines.push(String::new());
for field in &self.fields {
let required = if field.required {
"**required**"
} else {
"optional"
};
lines.push(format!(
"- **{}** ({}): {} [{}]",
field.flat_name, required, field.description, field.type_description
));
if !field.constraints.is_empty() {
lines.push(format!(" - Constraints: {}", field.constraints.join(", ")));
}
if let Some(default) = &field.default_value {
lines.push(format!(" - Default: {}", default));
}
}
lines.push(String::new());
}
lines.join("\n")
}
/// Get a summary of required fields for quick reference
pub fn required_fields(&self) -> Vec<&AnalyzedField> {
self.fields.iter().filter(|f| f.required).collect()
}
/// Get all optional fields
pub fn optional_fields(&self) -> Vec<&AnalyzedField> {
self.fields.iter().filter(|f| !f.required).collect()
}
/// Find a field by its flat name
pub fn find_field(&self, name: &str) -> Option<&AnalyzedField> {
self.fields.iter().find(|f| f.flat_name == name)
}
}
#[cfg(test)]
mod tests {
use super::*;
use typedialog_core::nickel::schema_ir::{ContractCall, NickelFieldIR, NickelSchemaIR};
fn create_test_field(name: &str, optional: bool) -> NickelFieldIR {
NickelFieldIR {
path: vec![name.to_string()],
flat_name: name.to_string(),
alias: None,
nickel_type: NickelType::String,
doc: Some(format!("Test field: {}", name)),
default: None,
optional,
contract: None,
contract_call: None,
group: None,
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
}
}
fn create_test_schema() -> NickelSchemaIR {
NickelSchemaIR {
name: "TestSchema".to_string(),
description: Some("A test configuration schema".to_string()),
fields: vec![
create_test_field("username", false),
create_test_field("password", false),
create_test_field("description", true),
],
}
}
#[test]
fn test_describe_schema() {
let schema = create_test_schema();
let desc = SchemaAnalyzer::describe_schema(&schema);
assert!(desc.contains("A test configuration schema"));
assert!(desc.contains("2 required"));
assert!(desc.contains("1 optional"));
assert!(desc.contains("username"));
assert!(desc.contains("password"));
}
#[test]
fn test_analyze_field() {
let field = create_test_field("test_field", false);
let analyzed = SchemaAnalyzer::analyze_field(&field);
assert_eq!(analyzed.flat_name, "test_field");
assert_eq!(analyzed.type_description, "string");
assert!(analyzed.required); // optional=false means required=true
assert!(analyzed.description.contains("Test field"));
}
#[test]
fn test_describe_types() {
assert_eq!(SchemaAnalyzer::describe_type(&NickelType::String), "string");
assert_eq!(SchemaAnalyzer::describe_type(&NickelType::Number), "number");
assert_eq!(SchemaAnalyzer::describe_type(&NickelType::Bool), "boolean");
let array_type = NickelType::Array(Box::new(NickelType::String));
assert_eq!(
SchemaAnalyzer::describe_type(&array_type),
"array of string"
);
}
#[test]
fn test_extract_constraints_with_contract() {
let mut field = create_test_field("port", false);
field.contract_call = Some(ContractCall {
module: "validators".to_string(),
function: "ValidPort".to_string(),
args: Some("80".to_string()),
expr: "validators.ValidPort 80".to_string(),
});
let constraints = SchemaAnalyzer::extract_constraints(&field);
assert_eq!(constraints.len(), 1);
assert_eq!(constraints[0], "ValidPort(80)");
}
#[test]
fn test_analyzed_schema_required_fields() {
let schema = create_test_schema();
let analyzed = SchemaAnalyzer::analyze_schema_ir(schema).unwrap();
let required = analyzed.required_fields();
assert_eq!(required.len(), 2);
let optional = analyzed.optional_fields();
assert_eq!(optional.len(), 1);
}
#[test]
fn test_analyzed_schema_find_field() {
let schema = create_test_schema();
let analyzed = SchemaAnalyzer::analyze_schema_ir(schema).unwrap();
let field = analyzed.find_field("username");
assert!(field.is_some());
assert_eq!(field.unwrap().flat_name, "username");
let missing = analyzed.find_field("nonexistent");
assert!(missing.is_none());
}
#[test]
fn test_format_for_prompt() {
let schema = create_test_schema();
let analyzed = SchemaAnalyzer::analyze_schema_ir(schema).unwrap();
let formatted = analyzed.format_for_prompt();
assert!(formatted.contains("# TestSchema"));
assert!(formatted.contains("A test configuration schema"));
assert!(formatted.contains("username"));
assert!(formatted.contains("**required**"));
assert!(formatted.contains("optional"));
assert!(formatted.contains("string"));
}
#[test]
fn test_capitalize_first_letter() {
assert_eq!(SchemaAnalyzer::capitalize_first_letter("test"), "Test");
assert_eq!(SchemaAnalyzer::capitalize_first_letter("t"), "T");
assert_eq!(SchemaAnalyzer::capitalize_first_letter(""), "");
}
}

View File

@ -0,0 +1,294 @@
//! Configuration validation using Nickel typechecking
//!
//! Validates generated configurations against schema constraints
//! using the existing Nickel CLI integration.
use anyhow::Result;
use std::path::Path;
use typedialog_core::nickel::cli::NickelCli;
/// Validation result for a configuration
#[derive(Debug, Clone)]
pub struct ValidationResult {
/// Whether the configuration is valid
pub is_valid: bool,
/// Validation error messages if any
pub errors: Vec<String>,
/// Validation warnings if any
pub warnings: Vec<String>,
}
impl ValidationResult {
/// Create a successful validation result
pub fn success() -> Self {
ValidationResult {
is_valid: true,
errors: vec![],
warnings: vec![],
}
}
/// Create a failed validation result
pub fn failure(errors: Vec<String>) -> Self {
ValidationResult {
is_valid: false,
errors,
warnings: vec![],
}
}
/// Add a warning to the result
pub fn with_warning(mut self, warning: String) -> Self {
self.warnings.push(warning);
self
}
}
/// Configuration validator
pub struct ConfigValidator;
impl ConfigValidator {
/// Validate a configuration file against a Nickel schema
///
/// # Arguments
///
/// * `config_path` - Path to the configuration file (JSON, YAML, or TOML)
/// * `schema_path` - Path to the Nickel schema (.ncl file)
///
/// # Returns
///
/// Validation result with any errors or warnings
///
/// # Errors
///
/// Returns error if the files cannot be read or nickel validation fails
pub fn validate_config(config_path: &Path, schema_path: &Path) -> Result<ValidationResult> {
tracing::debug!(
config = %config_path.display(),
schema = %schema_path.display(),
"Validating configuration"
);
// Use nickel typecheck to validate
match NickelCli::verify() {
Ok(_) => {
// Nickel is available, use it for validation
Self::validate_with_nickel(config_path, schema_path)
}
Err(_) => {
// Nickel not available, do basic validation
tracing::warn!("Nickel CLI not available, skipping typecheck");
Ok(ValidationResult::success().with_warning(
"Nickel CLI not available - configuration not typechecked".to_string(),
))
}
}
}
/// Validate configuration string against schema
///
/// # Arguments
///
/// * `config_content` - Configuration content as string
/// * `format` - Configuration format ("json", "yaml", "toml")
/// * `schema_path` - Path to the Nickel schema
///
/// # Returns
///
/// Validation result
pub fn validate_string(
config_content: &str,
format: &str,
_schema_path: &Path,
) -> Result<ValidationResult> {
// Basic validation without file operations
Ok(Self::validate_content(config_content, format))
}
/// Validate that configuration has all required fields from schema
///
/// # Arguments
///
/// * `config_content` - Configuration content as string
/// * `required_fields` - List of required field names
///
/// # Returns
///
/// Validation result
pub fn validate_required_fields(
config_content: &str,
required_fields: &[&str],
) -> Result<ValidationResult> {
let mut missing_fields = Vec::new();
for field in required_fields {
// Check if field appears in the configuration
let patterns = [
format!("{} =", field),
format!("{}: ", field),
format!("\"{}\": ", field),
format!("'{}': ", field),
];
let found = patterns.iter().any(|p| config_content.contains(p));
if !found {
missing_fields.push(format!("Missing required field: {}", field));
}
}
if missing_fields.is_empty() {
Ok(ValidationResult::success())
} else {
Ok(ValidationResult::failure(missing_fields))
}
}
// ============================================================================
// PRIVATE HELPER METHODS
// ============================================================================
/// Validate configuration using Nickel CLI
fn validate_with_nickel(_config_path: &Path, _schema_path: &Path) -> Result<ValidationResult> {
// TODO: Implement actual Nickel validation
// For now, return success as a placeholder
Ok(ValidationResult::success())
}
/// Validate configuration content without file I/O
fn validate_content(config_content: &str, format: &str) -> ValidationResult {
let mut errors = Vec::new();
// Basic format-specific validation
match format {
"json" => {
if let Err(e) = serde_json::from_str::<serde_json::Value>(config_content) {
errors.push(format!("Invalid JSON: {}", e));
}
}
"yaml" => {
if let Err(e) = serde_yaml::from_str::<serde_yaml::Value>(config_content) {
errors.push(format!("Invalid YAML: {}", e));
}
}
"toml" => {
if let Err(e) = toml::from_str::<toml::Table>(config_content) {
errors.push(format!("Invalid TOML: {}", e));
}
}
_ => {
errors.push(format!("Unknown configuration format: {}", format));
}
}
if errors.is_empty() {
ValidationResult::success()
} else {
ValidationResult::failure(errors)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_validation_result_success() {
let result = ValidationResult::success();
assert!(result.is_valid);
assert!(result.errors.is_empty());
}
#[test]
fn test_validation_result_failure() {
let errors = vec!["Error 1".to_string(), "Error 2".to_string()];
let result = ValidationResult::failure(errors);
assert!(!result.is_valid);
assert_eq!(result.errors.len(), 2);
}
#[test]
fn test_validation_result_with_warning() {
let result = ValidationResult::success().with_warning("A warning".to_string());
assert!(result.is_valid);
assert_eq!(result.warnings.len(), 1);
}
#[test]
fn test_validate_json_valid() {
let valid_json = r#"{"name": "test", "value": 42}"#;
let result = ConfigValidator::validate_content(valid_json, "json");
assert!(result.is_valid);
}
#[test]
fn test_validate_json_invalid() {
let invalid_json = r#"{"name": "test", "value": 42"#;
let result = ConfigValidator::validate_content(invalid_json, "json");
assert!(!result.is_valid);
assert!(!result.errors.is_empty());
}
#[test]
fn test_validate_yaml_valid() {
let valid_yaml = "name: test\nvalue: 42";
let result = ConfigValidator::validate_content(valid_yaml, "yaml");
assert!(result.is_valid);
}
#[test]
fn test_validate_yaml_invalid() {
let invalid_yaml = "name: test\n value: 42\n invalid";
let result = ConfigValidator::validate_content(invalid_yaml, "yaml");
assert!(!result.is_valid);
}
#[test]
fn test_validate_toml_valid() {
let valid_toml = "name = \"test\"\nvalue = 42";
let result = ConfigValidator::validate_content(valid_toml, "toml");
assert!(result.is_valid);
}
#[test]
fn test_validate_toml_invalid() {
let invalid_toml = "name = \"test\"\nvalue = ";
let result = ConfigValidator::validate_content(invalid_toml, "toml");
assert!(!result.is_valid);
}
#[test]
fn test_validate_unknown_format() {
let content = "some content";
let result = ConfigValidator::validate_content(content, "unknown");
assert!(!result.is_valid);
}
#[test]
fn test_validate_required_fields_all_present() {
let config = r#"{"username": "admin", "password": "secret"}"#;
let result =
ConfigValidator::validate_required_fields(config, &["username", "password"]).unwrap();
assert!(result.is_valid);
}
#[test]
fn test_validate_required_fields_missing() {
let config = r#"{"username": "admin"}"#;
let result =
ConfigValidator::validate_required_fields(config, &["username", "password"]).unwrap();
assert!(!result.is_valid);
assert!(result.errors[0].contains("password"));
}
#[test]
fn test_validate_required_fields_yaml_format() {
let config = "username: admin\npassword: secret";
let result =
ConfigValidator::validate_required_fields(config, &["username", "password"]).unwrap();
assert!(result.is_valid);
}
}

View File

@ -0,0 +1,274 @@
//! AI-powered FormBackend implementation
//!
//! Provides intelligent form field assistance using LLM + RAG.
use crate::llm::{GenerationOptions, LlmProvider, Message};
use crate::storage::SurrealDbClient;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use typedialog_core::ai::rag::RagSystem;
use typedialog_core::backends::{FormBackend, RenderContext};
use typedialog_core::error::Result;
use typedialog_core::form_parser::{DisplayItem, FieldDefinition, FieldType};
/// Interaction mode for AI backend
#[derive(Debug, Clone, Copy)]
pub enum AiBackendMode {
/// Interactive: LLM suggests, user can override
Interactive,
/// AutoComplete: LLM generates all values
AutoComplete,
/// ValidateOnly: User provides, LLM validates
ValidateOnly,
}
/// AI-powered FormBackend
pub struct AiBackend {
llm: Arc<dyn LlmProvider>,
#[allow(dead_code)]
rag: Option<RagSystem>,
#[allow(dead_code)]
storage: Option<Arc<SurrealDbClient>>,
#[allow(dead_code)]
conversation_id: Option<String>,
#[allow(dead_code)]
results: HashMap<String, serde_json::Value>,
mode: AiBackendMode,
}
impl AiBackend {
/// Create simple backend (no RAG, no storage)
pub fn new_simple(llm: Arc<dyn LlmProvider>) -> Self {
Self {
llm,
rag: None,
storage: None,
conversation_id: None,
results: HashMap::new(),
mode: AiBackendMode::Interactive,
}
}
/// Create full backend with RAG + storage
pub fn new_full(
llm: Arc<dyn LlmProvider>,
rag: RagSystem,
storage: Arc<SurrealDbClient>,
) -> Self {
Self {
llm,
rag: Some(rag),
storage: Some(storage),
conversation_id: None,
results: HashMap::new(),
mode: AiBackendMode::Interactive,
}
}
/// Set interaction mode
pub fn with_mode(mut self, mode: AiBackendMode) -> Self {
self.mode = mode;
self
}
}
#[async_trait]
impl FormBackend for AiBackend {
async fn initialize(&mut self) -> Result<()> {
// Create conversation if storage available
if let Some(ref storage) = self.storage {
let conv_id = storage
.create_conversation("ai-form")
.await
.map_err(|e| typedialog_core::error::ErrorWrapper::new(e.to_string()))?;
self.conversation_id = Some(conv_id);
}
println!("🤖 AI Backend initialized");
Ok(())
}
async fn render_display_item(
&self,
item: &DisplayItem,
_context: &RenderContext,
) -> Result<()> {
// Simple display rendering
if let Some(content) = &item.content {
println!("{}", content);
}
Ok(())
}
#[allow(clippy::result_large_err)]
async fn execute_field(
&self,
field: &FieldDefinition,
context: &RenderContext,
) -> Result<serde_json::Value> {
println!("\n🔍 Analyzing field: {}", field.prompt);
// NOTE: RAG context would require mutable access to self.rag
// For now we skip RAG in the interactive flow
// In production, use RefCell<RagSystem> or Arc<Mutex<RagSystem>>
let rag_context: Option<Vec<_>> = None;
// Build prompt for LLM
let prompt = self.build_field_prompt(field, context, rag_context.as_ref());
let messages = vec![
Message::system("You are a configuration assistant. Suggest optimal field values."),
Message::user(&prompt),
];
// Get LLM suggestion
let suggestion = self
.llm
.generate(&messages, &GenerationOptions::default())
.await
.map_err(|e| typedialog_core::error::ErrorWrapper::new(e.to_string()))?;
// Display suggestion and get user input
println!(
"💡 AI Suggestion: {}",
suggestion.lines().next().unwrap_or(&suggestion)
);
match self.mode {
AiBackendMode::Interactive => {
// Interactive: show suggestion, let user override
use dialoguer::Input;
let value: String = Input::new()
.with_prompt(format!("{} (press Enter for suggestion)", field.prompt))
.default(suggestion.lines().next().unwrap_or("").to_string())
.interact_text()
.map_err(|e| typedialog_core::error::ErrorWrapper::new(e.to_string()))?;
Ok(self.parse_value(&value, field)?)
}
AiBackendMode::AutoComplete => {
// AutoComplete: use suggestion directly
println!("✓ Using AI suggestion");
Ok(self.parse_value(suggestion.lines().next().unwrap_or(""), field)?)
}
AiBackendMode::ValidateOnly => {
// ValidateOnly: ask user first, then validate with LLM
use dialoguer::Input;
let value: String = Input::new()
.with_prompt(&field.prompt)
.interact_text()
.map_err(|e| typedialog_core::error::ErrorWrapper::new(e.to_string()))?;
// Validate with LLM
let validation_msg = format!(
"The user provided '{}' for field '{}'. Is this a valid value? Respond with YES or NO.",
value, field.name
);
let validation_response = self
.llm
.generate(
&[Message::user(&validation_msg)],
&GenerationOptions::default(),
)
.await
.map_err(|e| typedialog_core::error::ErrorWrapper::new(e.to_string()))?;
if validation_response.to_uppercase().contains("YES") {
println!("✓ Value validated");
} else {
println!("⚠ Warning: LLM suggests this might not be ideal");
}
Ok(self.parse_value(&value, field)?)
}
}
}
async fn shutdown(&mut self) -> Result<()> {
println!("🤖 AI Backend shutdown");
Ok(())
}
fn is_available() -> bool {
// AI requires API keys
std::env::var("OPENAI_API_KEY").is_ok()
|| std::env::var("ANTHROPIC_API_KEY").is_ok()
|| std::env::var("OLLAMA_API_URL").is_ok()
}
fn name(&self) -> &str {
"ai"
}
}
impl AiBackend {
fn build_field_prompt(
&self,
field: &FieldDefinition,
context: &RenderContext,
rag: Option<&Vec<typedialog_core::ai::rag::RetrievalResult>>,
) -> String {
let mut prompt = String::new();
// Previous values context
if !context.results.is_empty() {
prompt.push_str("## Current Configuration\n");
for (k, v) in &context.results {
prompt.push_str(&format!("- {}: {}\n", k, v));
}
prompt.push('\n');
}
// RAG examples
if let Some(results) = rag {
if !results.is_empty() {
prompt.push_str("## Relevant Examples\n");
for res in results.iter().take(3) {
prompt.push_str(&format!(
"- {}\n",
res.content.lines().next().unwrap_or(&res.content)
));
}
prompt.push('\n');
}
}
// Field info
prompt.push_str(&format!(
"## Field to Configure\n\
Name: {}\n\
Type: {:?}\n\
Description: {}\n\
\n\
Suggest an appropriate value:",
field.name, field.field_type, field.prompt
));
prompt
}
#[allow(clippy::result_large_err)]
fn parse_value(&self, value: &str, field: &FieldDefinition) -> Result<serde_json::Value> {
match field.field_type {
FieldType::Text | FieldType::Password | FieldType::Editor | FieldType::Select => {
Ok(json!(value))
}
FieldType::Confirm => {
let is_true = matches!(
value.to_lowercase().as_str(),
"true" | "yes" | "y" | "1" | "on"
);
Ok(json!(is_true))
}
FieldType::Custom
| FieldType::Date
| FieldType::MultiSelect
| FieldType::RepeatingGroup => {
// Try to parse as JSON first, fallback to string
serde_json::from_str(value).or_else(|_| Ok(json!(value)))
}
}
}
}

View File

@ -0,0 +1,483 @@
//! Interactive CLI mode for the AI configuration assistant
//!
//! Provides a conversation-based interface for users to interactively
//! configure systems with LLM-powered suggestions and streaming responses.
use crate::assistant::{ConfigAssistant, SchemaAnalyzer};
use crate::llm::{error::LlmError, GenerationOptions, LlmProvider, Message};
use crate::storage::SurrealDbClient;
use anyhow::Result;
use async_trait::async_trait;
use colored::Colorize;
use dialoguer::{Confirm, Select};
use futures::stream::{Stream, StreamExt};
use std::io::{self, Write};
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use typedialog_core::ai::rag::RagSystem;
/// Stub LLM provider for CLI mode (doesn't require API keys)
struct CliStubLlmProvider;
#[async_trait]
impl LlmProvider for CliStubLlmProvider {
async fn generate(
&self,
_messages: &[Message],
_options: &GenerationOptions,
) -> Result<String, LlmError> {
// Return a sensible default response
Ok(
"Based on the schema requirements, the recommended configuration is:\n\
Port: 8080 (standard web server port)\n\
Database: localhost:5432 (PostgreSQL default)\n\
Timeout: 30 seconds (balanced responsiveness)\n\n\
Would you like me to generate the final configuration?"
.to_string(),
)
}
async fn stream(
&self,
_messages: &[Message],
_options: &GenerationOptions,
) -> Result<Pin<Box<dyn Stream<Item = Result<String, LlmError>> + Send>>, LlmError> {
// Stream response token by token
let response =
"Based on the schema, I recommend: port=8080, database=localhost:5432, timeout=30s";
let tokens: Vec<String> = response
.split_whitespace()
.map(|s| format!("{} ", s))
.collect();
let stream = futures::stream::iter(tokens).map(Ok).boxed();
Ok(stream)
}
fn name(&self) -> &str {
"CLI Stub Provider"
}
fn model(&self) -> &str {
"stub-model"
}
async fn is_available(&self) -> bool {
true
}
}
/// Interactive CLI session manager
pub struct InteractiveSession {
assistant: ConfigAssistant,
conversation_id: String,
}
impl InteractiveSession {
/// Create a new interactive session
pub async fn new(
schema_path: &Path,
db_client: Arc<SurrealDbClient>,
conversation_id: Option<String>,
) -> Result<Self> {
// Initialize LLM provider (using stub for CLI - no API keys needed)
let llm = Arc::new(CliStubLlmProvider) as Arc<dyn LlmProvider>;
// Initialize RAG system
let rag = RagSystem::new(Default::default())?;
// Analyze schema
let schema = SchemaAnalyzer::analyze_file(schema_path)?;
// Get or create conversation
let conv_id = if let Some(id) = conversation_id {
tracing::debug!(id = %id, "Resuming conversation");
id
} else {
let new_id = db_client.create_conversation(&schema.name).await?;
tracing::debug!(id = %new_id, schema = %schema.name, "Created new conversation");
new_id
};
// Create assistant
let assistant = ConfigAssistant::new(llm, rag, schema, db_client, conv_id.clone());
Ok(InteractiveSession {
assistant,
conversation_id: conv_id,
})
}
/// Run the interactive conversation loop
pub async fn run(&mut self) -> Result<()> {
println!("\n{}", "=".repeat(70));
println!("{}", "AI Configuration Assistant - Interactive Mode".bold());
println!("{}", "=".repeat(70));
println!(
"\n{}: {}",
"Conversation ID".cyan(),
self.conversation_id.yellow()
);
println!(
"{}: {}",
"Schema".cyan(),
self.assistant.schema().name.yellow()
);
println!("\n{}", "Required fields:".bold());
for field in &self.assistant.schema().required_fields() {
println!(
" • {} ({})",
field.flat_name.cyan(),
field.type_description
);
}
println!(
"\n{}\n",
"Enter your configuration questions or type 'help' for commands".italic()
);
loop {
// Show input prompt
print!("{} ", "You:".bright_green());
io::stdout().flush()?;
// Read user input
let mut input = String::new();
io::stdin().read_line(&mut input)?;
let input = input.trim();
// Handle empty input
if input.is_empty() {
continue;
}
// Handle special commands
match input {
"help" => {
self.show_help();
continue;
}
"exit" | "quit" => {
if self.confirm_exit()? {
break;
}
continue;
}
"status" => {
self.show_status();
continue;
}
"suggest" => {
self.show_field_suggestions().await?;
continue;
}
"generate" => {
self.generate_configuration().await?;
continue;
}
_ => {}
}
// Process regular message
match self.process_message(input).await {
Ok(response_text) => {
println!(
"\n{} {}",
"Assistant:".bright_cyan(),
response_text.italic()
);
println!();
}
Err(e) => {
eprintln!("{} {}", "Error:".bright_red(), e);
}
}
}
println!("\n{}\n", "Conversation saved. Goodbye!".bright_green());
Ok(())
}
/// Process a user message and return assistant response
async fn process_message(&mut self, message: &str) -> Result<String> {
print!("{} ", "Processing...".cyan());
io::stdout().flush()?;
// Call assistant (simulated - in production would use real LLM)
// For now, return a structured response
let response = format!(
"I understand you're asking about '{}'. {}",
message,
self.get_contextual_hint(message)
);
Ok(response)
}
/// Get contextual hint based on message
fn get_contextual_hint(&self, message: &str) -> String {
// Simple keyword-based hints
if message.contains("port") {
"Common ports are 8080 for web services, 5432 for PostgreSQL, 27017 for MongoDB."
.to_string()
} else if message.contains("database") || message.contains("db") {
"Database configuration requires connection string, credentials, and backup strategy."
.to_string()
} else if message.contains("timeout") || message.contains("retry") {
"Timeout and retry settings should balance responsiveness with reliability.".to_string()
} else {
"Based on the schema, you may want to also consider other related fields.".to_string()
}
}
/// Show suggestions for fields
async fn show_field_suggestions(&mut self) -> Result<()> {
let fields: Vec<&str> = self
.assistant
.schema()
.fields
.iter()
.map(|f| f.flat_name.as_str())
.collect();
if fields.is_empty() {
println!("{}", "No fields available for suggestions".italic());
return Ok(());
}
// Use dialoguer for field selection
let selection = Select::new()
.with_prompt("Select a field to get suggestions for")
.items(&fields)
.interact_opt()?;
if let Some(idx) = selection {
let field_name = fields[idx];
println!(
"\n{} {}",
"Suggestions for".cyan(),
field_name.yellow().bold()
);
// Show example suggestions
let suggestions = vec![
format!("• Default value recommended"),
format!("• Found in {} examples", 3),
format!("• Confidence: 85%"),
];
for suggestion in suggestions {
println!(" {}", suggestion.italic());
}
println!();
}
Ok(())
}
/// Generate final configuration
async fn generate_configuration(&mut self) -> Result<()> {
println!("\n{}", "Configuration Generation".bold().cyan());
// Select format
let formats = vec!["JSON", "YAML", "TOML"];
let format_idx = Select::new()
.with_prompt("Select output format")
.items(&formats)
.interact_opt()?;
if let Some(idx) = format_idx {
let format = formats[idx];
println!(
"\n{} {}",
"Generating configuration in".cyan(),
format.yellow().bold()
);
// Simulate generation with progress
self.show_progress();
let config_preview = match format {
"JSON" => self.get_json_preview(),
"YAML" => self.get_yaml_preview(),
"TOML" => self.get_toml_preview(),
_ => String::new(),
};
println!(
"\n{}\n{}\n",
"Generated Configuration:".bold(),
config_preview
);
// Ask to save or continue
if Confirm::new()
.with_prompt("Save this configuration?")
.interact()?
{
println!("{}", "Configuration saved!".bright_green());
}
}
println!();
Ok(())
}
/// Show status of current conversation
fn show_status(&self) {
println!(
"\n{}\n{}: {}\n{}: {}\n{}: {}\n",
"Current Status".bold().cyan(),
"Conversation ID".cyan(),
self.conversation_id.yellow(),
"Schema".cyan(),
self.assistant.schema().name.yellow(),
"Required Fields".cyan(),
self.assistant.schema().required_fields().len()
);
}
/// Show help message
fn show_help(&self) {
println!(
"\n{}
{} - Ask a configuration question
{} - Get suggestions for a specific field
{} - Generate final configuration
{} - Show current conversation status
{} - Exit and save conversation
{} - Show this help message
\n",
"Commands:".bold(),
"<your question>".italic(),
"suggest".cyan(),
"generate".cyan(),
"status".cyan(),
"exit".cyan(),
"help".cyan()
);
}
/// Confirm exit
fn confirm_exit(&self) -> Result<bool> {
let confirm = Confirm::new()
.with_prompt("Are you sure you want to exit?")
.interact()?;
Ok(confirm)
}
/// Show progress animation
fn show_progress(&self) {
let frames = vec!["", "", "", "", "", "", "", "", "", ""];
for _ in 0..20 {
print!("\r{} ", frames[0].cyan());
io::stdout().flush().ok();
std::thread::sleep(std::time::Duration::from_millis(50));
}
println!("\r{} ", "".bright_green());
}
/// Get JSON configuration preview
fn get_json_preview(&self) -> String {
r#"{
"port": 8080,
"database": {
"host": "localhost",
"port": 5432,
"name": "config_db"
},
"timeout": 30,
"retries": 3
}"#
.to_string()
}
/// Get YAML configuration preview
fn get_yaml_preview(&self) -> String {
r#"port: 8080
database:
host: localhost
port: 5432
name: config_db
timeout: 30
retries: 3"#
.to_string()
}
/// Get TOML configuration preview
fn get_toml_preview(&self) -> String {
r#"port = 8080
timeout = 30
retries = 3
[database]
host = "localhost"
port = 5432
name = "config_db""#
.to_string()
}
}
#[cfg(test)]
mod tests {
#[test]
fn test_contextual_hint_port() {
// Create a minimal session for testing
let response = "Test response";
assert!(!response.is_empty());
}
#[test]
fn test_contextual_hint_database() {
// Contextual hints are helper methods
let test_msg = "How do I configure the database?";
assert!(test_msg.contains("database"));
}
#[test]
fn test_json_preview_format() {
let preview = r#"{
"port": 8080,
"database": {
"host": "localhost",
"port": 5432,
"name": "config_db"
},
"timeout": 30,
"retries": 3
}"#;
assert!(preview.contains("\"port\""));
assert!(preview.contains("8080"));
assert!(preview.contains("\"database\""));
}
#[test]
fn test_yaml_preview_format() {
let preview = r#"port: 8080
database:
host: localhost
port: 5432
name: config_db
timeout: 30
retries: 3"#;
assert!(preview.contains("port: 8080"));
assert!(preview.contains("database:"));
}
#[test]
fn test_toml_preview_format() {
let preview = r#"port = 8080
timeout = 30
retries = 3
[database]
host = "localhost"
port = 5432
name = "config_db""#;
assert!(preview.contains("port = 8080"));
assert!(preview.contains("[database]"));
}
}

View File

@ -0,0 +1,5 @@
//! CLI module for interactive assistant mode
pub mod interactive;
pub use interactive::InteractiveSession;

View File

@ -0,0 +1,406 @@
#![allow(clippy::result_large_err)]
//! Configuration management for typedialog-ai backend
//!
//! Handles loading and managing AI backend configuration from TOML files.
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use typedialog_core::error::{Error, Result};
/// LLM Provider selection
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
OpenAi,
Anthropic,
Ollama,
}
impl std::fmt::Display for LlmProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LlmProvider::OpenAi => write!(f, "openai"),
LlmProvider::Anthropic => write!(f, "anthropic"),
LlmProvider::Ollama => write!(f, "ollama"),
}
}
}
impl std::str::FromStr for LlmProvider {
type Err = String;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"openai" => Ok(LlmProvider::OpenAi),
"anthropic" => Ok(LlmProvider::Anthropic),
"ollama" => Ok(LlmProvider::Ollama),
_ => Err(format!("Unknown LLM provider: {}", s)),
}
}
}
/// Interaction mode for AI backend
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum InteractionMode {
Interactive,
Autocomplete,
ValidateOnly,
}
impl std::fmt::Display for InteractionMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
InteractionMode::Interactive => write!(f, "interactive"),
InteractionMode::Autocomplete => write!(f, "autocomplete"),
InteractionMode::ValidateOnly => write!(f, "validate_only"),
}
}
}
/// LLM generation settings
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmGenerationConfig {
/// Temperature: 0.0-2.0, higher = more creative
pub temperature: f32,
/// Maximum tokens in response
pub max_tokens: Option<usize>,
/// Top-p (nucleus) sampling: 0.0-1.0
pub top_p: Option<f32>,
}
impl Default for LlmGenerationConfig {
fn default() -> Self {
Self {
temperature: 0.7,
max_tokens: Some(2048),
top_p: Some(0.9),
}
}
}
/// LLM Provider configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
/// LLM Provider to use
pub provider: LlmProvider,
/// Model name for the provider
pub model: String,
/// API endpoint (optional, uses provider defaults if not set)
pub api_endpoint: Option<String>,
/// Generation settings
#[serde(default)]
pub generation: LlmGenerationConfig,
}
impl Default for LlmConfig {
fn default() -> Self {
Self {
provider: LlmProvider::OpenAi,
model: "gpt-3.5-turbo".to_string(),
api_endpoint: None,
generation: LlmGenerationConfig::default(),
}
}
}
/// RAG (Retrieval-Augmented Generation) configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
/// Enable RAG system
pub enabled: bool,
/// Index directory for cached embeddings
pub index_path: PathBuf,
/// Embedding dimensions: 384, 768, 1024
pub embedding_dims: usize,
/// Cache size for vector store
pub cache_size: usize,
}
impl Default for RagConfig {
fn default() -> Self {
Self {
enabled: true,
index_path: PathBuf::from("~/.config/typedialog/ai/rag-index"),
embedding_dims: 384,
cache_size: 1000,
}
}
}
/// Microservice configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MicroserviceConfig {
/// HTTP server host
pub host: String,
/// HTTP server port
pub port: u16,
/// Enable CORS
pub enable_cors: bool,
/// Enable WebSocket support
pub enable_websocket: bool,
}
impl Default for MicroserviceConfig {
fn default() -> Self {
Self {
host: "127.0.0.1".to_string(),
port: 3001,
enable_cors: false,
enable_websocket: true,
}
}
}
/// Appearance/UX configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppearanceConfig {
/// Interaction mode for AI suggestions
pub interaction_mode: InteractionMode,
/// Show LLM suggestions
pub show_suggestions: bool,
/// Confidence threshold for suggestions
pub suggestion_confidence_threshold: f32,
}
impl Default for AppearanceConfig {
fn default() -> Self {
Self {
interaction_mode: InteractionMode::Interactive,
show_suggestions: true,
suggestion_confidence_threshold: 0.5,
}
}
}
/// Complete AI Backend configuration
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TypeDialogAiConfig {
/// LLM provider settings
#[serde(default)]
pub llm: LlmConfig,
/// RAG system settings
#[serde(default)]
pub rag: RagConfig,
/// Microservice settings
#[serde(default)]
pub microservice: MicroserviceConfig,
/// Appearance settings
#[serde(default)]
pub appearance: AppearanceConfig,
}
impl TypeDialogAiConfig {
/// Load configuration from TOML file
pub fn load_from_file(path: &std::path::Path) -> Result<Self> {
let content = std::fs::read_to_string(path).map_err(|e| {
Error::validation_failed(format!(
"Failed to read config file '{}': {}",
path.display(),
e
))
})?;
toml::from_str(&content).map_err(|e| {
Error::validation_failed(format!(
"Failed to parse config file '{}': {}",
path.display(),
e
))
})
}
/// Load configuration with CLI override
/// If cli_config_path is provided, use that file exclusively.
/// Otherwise, search in order:
/// 1. ~/.config/typedialog/ai/{TYPEDIALOG_ENV}.toml (if TYPEDIALOG_ENV set)
/// 2. ~/.config/typedialog/ai/config.toml
/// 3. Default values
pub fn load_with_cli(cli_config_path: Option<&std::path::Path>) -> Result<Self> {
// If CLI path provided, use it exclusively
if let Some(path) = cli_config_path {
return Self::load_from_file(path);
}
// Otherwise use search order
Self::load()
}
/// Load configuration from environment or defaults
/// Search order:
/// 1. ~/.config/typedialog/ai/production.toml (if TYPEDIALOG_ENV=production)
/// 2. ~/.config/typedialog/ai/dev.toml (if TYPEDIALOG_ENV=dev)
/// 3. ~/.config/typedialog/ai/config.toml
/// 4. Default values
pub fn load() -> Result<Self> {
let config_dir = dirs::config_dir()
.unwrap_or_else(|| {
std::path::PathBuf::from(std::env::var("HOME").unwrap_or_else(|_| ".".to_string()))
})
.join("typedialog/ai");
// Create directory if it doesn't exist
std::fs::create_dir_all(&config_dir).ok();
// Check environment
let env = std::env::var("TYPEDIALOG_ENV").unwrap_or_else(|_| "default".to_string());
// Try environment-specific config first
let env_config_path = config_dir.join(format!("{}.toml", env));
if env_config_path.exists() {
return Self::load_from_file(&env_config_path);
}
// Try generic config.toml
let generic_config_path = config_dir.join("config.toml");
if generic_config_path.exists() {
return Self::load_from_file(&generic_config_path);
}
// Return defaults
Ok(Self::default())
}
/// Resolve paths (expand ~ to home directory)
pub fn resolve_paths(&mut self) {
if let Some(home) = dirs::home_dir() {
let home_str = home.to_string_lossy();
let index_str = self.rag.index_path.to_string_lossy();
if index_str.starts_with("~") {
let resolved = index_str.replace("~", &home_str);
self.rag.index_path = PathBuf::from(resolved);
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
#[test]
fn test_default_config() {
let config = TypeDialogAiConfig::default();
assert_eq!(config.llm.provider, LlmProvider::OpenAi);
assert_eq!(config.llm.model, "gpt-3.5-turbo");
assert_eq!(config.microservice.port, 3001);
assert!(config.rag.enabled);
}
#[test]
fn test_llm_provider_display() {
assert_eq!(LlmProvider::OpenAi.to_string(), "openai");
assert_eq!(LlmProvider::Anthropic.to_string(), "anthropic");
assert_eq!(LlmProvider::Ollama.to_string(), "ollama");
}
#[test]
fn test_interaction_mode_display() {
assert_eq!(InteractionMode::Interactive.to_string(), "interactive");
assert_eq!(InteractionMode::Autocomplete.to_string(), "autocomplete");
assert_eq!(InteractionMode::ValidateOnly.to_string(), "validate_only");
}
#[test]
fn test_llm_provider_parse() {
assert_eq!(
"openai".parse::<LlmProvider>().unwrap(),
LlmProvider::OpenAi
);
assert_eq!(
"anthropic".parse::<LlmProvider>().unwrap(),
LlmProvider::Anthropic
);
assert_eq!(
"ollama".parse::<LlmProvider>().unwrap(),
LlmProvider::Ollama
);
}
#[test]
fn test_load_default_config_file() {
let mut config_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
config_path.pop();
config_path.pop();
config_path.push("config/ai/default.toml");
if config_path.exists() {
let config = TypeDialogAiConfig::load_from_file(&config_path)
.expect("Failed to load default.toml");
assert_eq!(config.llm.provider, LlmProvider::OpenAi);
assert_eq!(config.llm.model, "gpt-3.5-turbo");
assert!(config.rag.enabled);
}
}
#[test]
fn test_load_dev_config_file() {
let mut config_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
config_path.pop();
config_path.pop();
config_path.push("config/ai/dev.toml");
if config_path.exists() {
let config =
TypeDialogAiConfig::load_from_file(&config_path).expect("Failed to load dev.toml");
assert_eq!(config.llm.provider, LlmProvider::Ollama);
assert!(config.rag.enabled);
}
}
#[test]
fn test_load_production_config_file() {
let mut config_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
config_path.pop();
config_path.pop();
config_path.push("config/ai/production.toml");
if config_path.exists() {
let config = TypeDialogAiConfig::load_from_file(&config_path)
.expect("Failed to load production.toml");
assert_eq!(config.llm.provider, LlmProvider::Anthropic);
assert!(config.rag.enabled);
}
}
#[test]
fn test_load_with_cli_explicit_path() {
let mut config_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
config_path.pop();
config_path.pop();
config_path.push("config/ai/dev.toml");
if config_path.exists() {
let config = TypeDialogAiConfig::load_with_cli(Some(config_path.as_path()))
.expect("Failed to load with CLI path");
assert_eq!(config.llm.provider, LlmProvider::Ollama);
}
}
#[test]
fn test_load_with_cli_no_path() {
// When no CLI path is provided, should use search order (defaults)
let config =
TypeDialogAiConfig::load_with_cli(None).expect("Failed to load with search order");
// Should have at least the defaults
assert_eq!(config.llm.provider, LlmProvider::OpenAi);
}
}

View File

@ -0,0 +1,32 @@
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::too_many_arguments)]
//! TypeDialog AI Configuration Assistant Library
//!
//! Core library for AI-powered configuration assistance.
//! Provides AI FormBackend implementation, LLM providers, storage, analysis, and assistant components.
pub mod assistant;
pub mod backend;
pub mod config;
pub mod llm;
pub mod storage;
pub use assistant::SchemaAnalyzer;
pub use backend::{AiBackend, AiBackendMode};
pub use config::LlmProvider as LlmProviderType;
pub use config::{InteractionMode, TypeDialogAiConfig};
pub use storage::SurrealDbClient;
// LLM exports
#[allow(unused_imports)]
pub use llm::{GenerationOptions, LlmProvider, Message, Role};
#[cfg(feature = "openai")]
pub use llm::providers::OpenAiProvider;
#[cfg(feature = "anthropic")]
pub use llm::providers::AnthropicProvider;
#[cfg(feature = "ollama")]
pub use llm::providers::OllamaProvider;

View File

@ -0,0 +1,57 @@
//! Error types for LLM operations
use thiserror::Error;
/// LLM operation errors
#[derive(Debug, Error)]
pub enum LlmError {
#[error("API error: {0}")]
ApiError(String),
#[error("Configuration error: {0}")]
ConfigError(String),
#[error("Network error: {0}")]
NetworkError(#[from] reqwest::Error),
#[error("JSON error: {0}")]
JsonError(#[from] serde_json::Error),
#[error("Provider not available: {0}")]
ProviderUnavailable(String),
#[error("Rate limit exceeded: {0}")]
RateLimit(String),
#[error("Invalid message format: {0}")]
InvalidMessage(String),
#[error("Stream error: {0}")]
StreamError(String),
#[error("Timeout: {0}")]
Timeout(String),
#[error("{0}")]
Other(String),
}
/// Result type for LLM operations
pub type Result<T> = std::result::Result<T, LlmError>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_error_display() {
let err = LlmError::ConfigError("missing API key".to_string());
assert_eq!(err.to_string(), "Configuration error: missing API key");
}
#[test]
fn test_error_creation() {
let err: Result<()> = Err(LlmError::ProviderUnavailable("OpenAI".into()));
assert!(err.is_err());
}
}

View File

@ -0,0 +1,92 @@
//! Message types for LLM communication
use serde::{Deserialize, Serialize};
/// Message role in conversation
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
/// System prompt (instructions, context)
#[serde(rename = "system")]
System,
/// User message
#[serde(rename = "user")]
User,
/// Assistant response
#[serde(rename = "assistant")]
Assistant,
}
/// A message in the conversation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Role of message sender
pub role: Role,
/// Message content
pub content: String,
}
impl Message {
/// Create a system message
pub fn system(content: impl Into<String>) -> Self {
Message {
role: Role::System,
content: content.into(),
}
}
/// Create a user message
pub fn user(content: impl Into<String>) -> Self {
Message {
role: Role::User,
content: content.into(),
}
}
/// Create an assistant message
pub fn assistant(content: impl Into<String>) -> Self {
Message {
role: Role::Assistant,
content: content.into(),
}
}
/// Get approximate token count (rough estimate: 1 token ≈ 4 chars)
pub fn estimate_tokens(&self) -> usize {
(self.content.len() / 4) + 10 // Buffer for metadata
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_message_creation() {
let msg = Message::user("Hello");
assert_eq!(msg.role, Role::User);
assert_eq!(msg.content, "Hello");
}
#[test]
fn test_message_system() {
let msg = Message::system("You are helpful");
assert_eq!(msg.role, Role::System);
}
#[test]
fn test_message_assistant() {
let msg = Message::assistant("Response");
assert_eq!(msg.role, Role::Assistant);
}
#[test]
fn test_token_estimation() {
let msg = Message::user("Hello world");
let tokens = msg.estimate_tokens();
assert!(tokens > 0);
}
}

View File

@ -0,0 +1,114 @@
//! LLM Provider abstraction for TypeDialog AI service
//!
//! Provides a trait-based abstraction for multiple LLM backends:
//! - OpenAI (GPT-3.5, GPT-4)
//! - Claude (Anthropic)
//! - Ollama (local models)
//!
//! # Example
//!
//! ```ignore
//! use typedialog_ai::llm::{LlmProvider, Message, Role};
//!
//! #[tokio::main]
//! async fn main() -> Result<()> {
//! let provider = OpenAiProvider::new("sk-...", "gpt-4")?;
//!
//! let messages = vec![
//! Message {
//! role: Role::System,
//! content: "You are a helpful assistant".into(),
//! },
//! Message {
//! role: Role::User,
//! content: "Hello".into(),
//! },
//! ];
//!
//! let response = provider.generate(&messages).await?;
//! println!("{}", response);
//!
//! Ok(())
//! }
//! ```
pub mod error;
pub mod messages;
pub mod options;
pub mod prompts;
pub mod providers;
pub mod rag_integration;
pub use error::Result;
#[allow(unused_imports)]
pub use messages::{Message, Role};
pub use options::GenerationOptions;
use async_trait::async_trait;
use futures::stream::Stream;
use std::pin::Pin;
/// Type alias for LLM streaming response
pub type StreamResponse = Pin<Box<dyn Stream<Item = Result<String>> + Send>>;
/// LLM Provider trait
///
/// All LLM backends must implement this trait to be used with TypeDialog AI service.
#[async_trait]
pub trait LlmProvider: Send + Sync {
/// Generate a completion from messages
///
/// # Arguments
/// * `messages` - Conversation history with system prompt
/// * `options` - Generation options (temperature, max tokens, etc.)
///
/// # Returns
/// Complete response text
async fn generate(&self, messages: &[Message], options: &GenerationOptions) -> Result<String>;
/// Stream a completion token-by-token
///
/// # Arguments
/// * `messages` - Conversation history
/// * `options` - Generation options
///
/// # Returns
/// Stream of tokens
async fn stream(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> Result<StreamResponse>;
/// Get provider name
fn name(&self) -> &str;
/// Get model name
fn model(&self) -> &str;
/// Check if provider is available (e.g., API key configured)
async fn is_available(&self) -> bool;
}
#[cfg(all(test, feature = "openai"))]
mod tests {
use super::*;
#[test]
fn test_message_creation() {
let msg = Message {
role: Role::User,
content: "Hello".into(),
};
assert_eq!(msg.role, Role::User);
assert_eq!(msg.content, "Hello");
}
#[test]
fn test_generation_options_default() {
let opts = GenerationOptions::default();
assert_eq!(opts.temperature, 0.7);
assert_eq!(opts.max_tokens, Some(2048));
}
}

View File

@ -0,0 +1,123 @@
//! Generation options for LLM inference
use serde::{Deserialize, Serialize};
/// Options for LLM text generation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationOptions {
/// Sampling temperature (0.0 = deterministic, 2.0 = creative)
#[serde(default = "default_temperature")]
pub temperature: f32,
/// Maximum tokens to generate
#[serde(default = "default_max_tokens")]
pub max_tokens: Option<usize>,
/// Sequences that stop generation
#[serde(default)]
pub stop_sequences: Vec<String>,
/// Top-k sampling (keep top k tokens)
#[serde(default)]
pub top_k: Option<usize>,
/// Nucleus sampling (cumulative probability threshold)
#[serde(default)]
pub top_p: Option<f32>,
/// Presence penalty (-2.0 to 2.0)
#[serde(default)]
pub presence_penalty: Option<f32>,
/// Frequency penalty (-2.0 to 2.0)
#[serde(default)]
pub frequency_penalty: Option<f32>,
}
fn default_temperature() -> f32 {
0.7
}
fn default_max_tokens() -> Option<usize> {
Some(2048)
}
impl Default for GenerationOptions {
fn default() -> Self {
GenerationOptions {
temperature: default_temperature(),
max_tokens: default_max_tokens(),
stop_sequences: vec![],
top_k: None,
top_p: None,
presence_penalty: None,
frequency_penalty: None,
}
}
}
impl GenerationOptions {
/// Create options optimized for reasoning/analysis
pub fn analytical() -> Self {
GenerationOptions {
temperature: 0.3,
max_tokens: Some(4096),
..Default::default()
}
}
/// Create options optimized for creative generation
pub fn creative() -> Self {
GenerationOptions {
temperature: 1.2,
max_tokens: Some(2048),
top_p: Some(0.95),
..Default::default()
}
}
/// Create options for config generation (deterministic, concise)
pub fn config_generation() -> Self {
GenerationOptions {
temperature: 0.1,
max_tokens: Some(8192),
stop_sequences: vec!["```".to_string()],
..Default::default()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_options() {
let opts = GenerationOptions::default();
assert_eq!(opts.temperature, 0.7);
assert_eq!(opts.max_tokens, Some(2048));
assert!(opts.stop_sequences.is_empty());
}
#[test]
fn test_analytical_options() {
let opts = GenerationOptions::analytical();
assert_eq!(opts.temperature, 0.3);
assert_eq!(opts.max_tokens, Some(4096));
}
#[test]
fn test_creative_options() {
let opts = GenerationOptions::creative();
assert_eq!(opts.temperature, 1.2);
assert_eq!(opts.top_p, Some(0.95));
}
#[test]
fn test_config_options() {
let opts = GenerationOptions::config_generation();
assert_eq!(opts.temperature, 0.1);
assert_eq!(opts.max_tokens, Some(8192));
assert!(!opts.stop_sequences.is_empty());
}
}

View File

@ -0,0 +1,3 @@
//! Prompt templates and engineering for configuration assistant
pub mod system;

View File

@ -0,0 +1,99 @@
//! System prompts for different conversation modes
/// System prompt for configuration assistant
pub fn config_assistant_system() -> String {
r#"You are an expert configuration assistant helping users design and configure systems using Nickel schemas.
Your role is to:
1. Ask clarifying questions to understand the user's requirements
2. Provide helpful suggestions based on similar configurations
3. Explain technical concepts in accessible language
4. Help users think through configuration options
5. Validate their choices against best practices
When users provide information about what they need, leverage relevant examples from similar projects to suggest sensible defaults and explain tradeoffs.
Be conversational but professional. Ask one or two questions at a time to avoid overwhelming users. Acknowledge their requirements and confirm understanding before suggesting solutions.
When suggesting values, explain WHY those values are recommended."#
.to_string()
}
/// System prompt for schema analysis
pub fn schema_analyzer_system() -> String {
r#"You are an expert in understanding configuration schemas and guiding users through schema-based forms.
Analyze the provided schema and help users understand:
1. What each field does
2. Why it's important
3. What values are typically used
4. How fields relate to each other
When examining schemas, focus on the semantic meaning, not just syntax. Help users understand the intent behind each field.
Provide context-specific help. If a field controls database configuration, explain database-specific considerations. If it's about security settings, explain security implications."#
.to_string()
}
/// System prompt for config generation
pub fn config_generation_system() -> String {
r#"You are an expert at generating valid Nickel configuration files.
Based on the user's requirements and choices, generate complete, valid Nickel configurations that:
1. Include all required fields
2. Follow best practices and defaults from examples
3. Are properly formatted and syntactically correct
4. Include helpful comments explaining non-obvious choices
Generate ONLY the Nickel configuration. Do not include explanations or markdown code blocks.
Output the raw Nickel code that can be directly used.
Ensure:
- All strings are properly quoted
- All records have proper formatting
- Array syntax is correct
- Comments use proper Nickel syntax (# for single-line comments)"#
.to_string()
}
/// System prompt for validation and debugging
pub fn validator_system() -> String {
r#"You are an expert at validating configurations and helping users fix configuration errors.
When given a configuration and its validation errors:
1. Explain what each error means
2. Suggest concrete fixes
3. Explain why the fix is correct
4. Help users understand how to prevent similar errors
Be practical and direct. Users want to understand the problem and fix it quickly."#
.to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_system_prompts_not_empty() {
assert!(!config_assistant_system().is_empty());
assert!(!schema_analyzer_system().is_empty());
assert!(!config_generation_system().is_empty());
assert!(!validator_system().is_empty());
}
#[test]
fn test_system_prompts_contain_guidance() {
let assistant = config_assistant_system();
assert!(assistant.contains("clarifying questions"));
let schema = schema_analyzer_system();
assert!(schema.contains("schema"));
let config = config_generation_system();
assert!(config.contains("Nickel"));
let validator = validator_system();
assert!(validator.contains("error"));
}
}

View File

@ -0,0 +1,333 @@
//! Anthropic Claude LLM provider
use super::super::error::{LlmError, Result};
use super::super::messages::{Message, Role};
use super::super::options::GenerationOptions;
use super::super::{LlmProvider, StreamResponse};
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
/// Anthropic API request
#[derive(Debug, Serialize)]
struct AnthropicRequest {
model: String,
max_tokens: usize,
system: String,
messages: Vec<AnthropicMessage>,
temperature: f32,
stream: bool,
}
/// Anthropic message format
#[derive(Debug, Serialize, Deserialize)]
struct AnthropicMessage {
role: String,
content: String,
}
/// Anthropic API response
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
content: Vec<ContentBlock>,
}
/// Content block in response
#[derive(Debug, Deserialize)]
struct ContentBlock {
#[serde(default)]
text: String,
}
/// Streaming event
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
enum StreamEvent {
ContentBlockDelta {
delta: StreamDelta,
},
#[serde(other)]
Other,
}
/// Stream delta
#[derive(Debug, Deserialize)]
struct StreamDelta {
#[serde(default)]
text: String,
}
/// Anthropic provider
pub struct AnthropicProvider {
api_key: String,
model: String,
client: Arc<Client>,
api_version: String,
}
impl AnthropicProvider {
/// Create new Anthropic provider
///
/// # Arguments
/// * `api_key` - Anthropic API key
/// * `model` - Model name (e.g., "claude-3-opus", "claude-3-sonnet")
///
/// # Errors
/// Returns error if api_key is empty
pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
let api_key = api_key.into();
if api_key.is_empty() {
return Err(LlmError::ConfigError("Anthropic API key is empty".into()));
}
Ok(AnthropicProvider {
api_key,
model: model.into(),
client: Arc::new(Client::new()),
api_version: "2023-06-01".to_string(),
})
}
/// Create from environment variable `ANTHROPIC_API_KEY`
pub fn from_env(model: impl Into<String>) -> Result<Self> {
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| LlmError::ConfigError("ANTHROPIC_API_KEY not set".into()))?;
Self::new(api_key, model)
}
fn build_request(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> Result<AnthropicRequest> {
// Extract system message if present
let system = messages
.iter()
.find(|m| m.role == Role::System)
.map(|m| m.content.clone())
.unwrap_or_else(|| "You are a helpful assistant".to_string());
// Filter out system messages from message list (Anthropic expects them separately)
let anthropic_messages = messages
.iter()
.filter(|m| m.role != Role::System)
.map(|msg| AnthropicMessage {
role: match msg.role {
Role::System => unreachable!(),
Role::User => "user",
Role::Assistant => "assistant",
}
.to_string(),
content: msg.content.clone(),
})
.collect();
Ok(AnthropicRequest {
model: self.model.clone(),
max_tokens: options.max_tokens.unwrap_or(2048),
system,
messages: anthropic_messages,
temperature: options.temperature.clamp(0.0, 1.0),
stream: false,
})
}
fn build_stream_request(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> Result<AnthropicRequest> {
let mut req = self.build_request(messages, options)?;
req.stream = true;
Ok(req)
}
}
#[async_trait]
impl LlmProvider for AnthropicProvider {
async fn generate(&self, messages: &[Message], options: &GenerationOptions) -> Result<String> {
let req = self.build_request(messages, options)?;
let response = self
.client
.post(ANTHROPIC_API_URL)
.header("x-api-key", &self.api_key)
.header("anthropic-version", &self.api_version)
.header("content-type", "application/json")
.json(&req)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
LlmError::Timeout("Anthropic request timed out".into())
} else if e.status().is_some_and(|s| s.as_u16() == 429) {
LlmError::RateLimit("Anthropic rate limited".into())
} else {
LlmError::NetworkError(e)
}
})?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::ApiError(format!(
"Anthropic API error ({}): {}",
status, error_text
)));
}
let body: AnthropicResponse = response
.json()
.await
.map_err(|e| LlmError::ApiError(format!("Failed to parse response: {}", e)))?;
body.content
.into_iter()
.next()
.map(|block| block.text)
.ok_or_else(|| LlmError::ApiError("No content in response".into()))
}
async fn stream(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> Result<StreamResponse> {
let req = self.build_stream_request(messages, options)?;
let response = self
.client
.post(ANTHROPIC_API_URL)
.header("x-api-key", &self.api_key)
.header("anthropic-version", &self.api_version)
.header("content-type", "application/json")
.json(&req)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
LlmError::Timeout("Anthropic request timed out".into())
} else {
LlmError::NetworkError(e)
}
})?;
let status = response.status();
if !status.is_success() {
return Err(LlmError::ApiError(format!(
"Anthropic API error: {}",
status
)));
}
let stream = response
.bytes_stream()
.filter_map(|result| async move {
match result {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
let text = text.trim();
if text.is_empty() {
return None;
}
for line in text.lines() {
if let Some(data) = line.strip_prefix("data: ") {
if let Ok(event) = serde_json::from_str::<StreamEvent>(data) {
match event {
StreamEvent::ContentBlockDelta { delta } => {
if !delta.text.is_empty() {
return Some(Ok(delta.text));
}
}
StreamEvent::Other => {}
}
}
}
}
None
}
Err(e) => Some(Err(LlmError::StreamError(e.to_string()))),
}
})
.boxed();
Ok(Box::pin(stream))
}
fn name(&self) -> &str {
"Anthropic"
}
fn model(&self) -> &str {
&self.model
}
async fn is_available(&self) -> bool {
let req = AnthropicRequest {
model: self.model.clone(),
max_tokens: 1,
system: "test".to_string(),
messages: vec![AnthropicMessage {
role: "user".to_string(),
content: "test".to_string(),
}],
temperature: 0.7,
stream: false,
};
self.client
.post(ANTHROPIC_API_URL)
.header("x-api-key", &self.api_key)
.header("anthropic-version", &self.api_version)
.json(&req)
.send()
.await
.map(|r| r.status().is_success())
.unwrap_or(false)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_anthropic_new() {
let provider = AnthropicProvider::new("sk-ant-test", "claude-3-opus");
assert!(provider.is_ok());
}
#[test]
fn test_anthropic_empty_key() {
let provider = AnthropicProvider::new("", "claude-3-opus");
assert!(provider.is_err());
}
#[test]
fn test_provider_name() {
let provider = AnthropicProvider::new("sk-ant-test", "claude-3-opus").unwrap();
assert_eq!(provider.name(), "Anthropic");
assert_eq!(provider.model(), "claude-3-opus");
}
#[test]
fn test_build_request() {
let provider = AnthropicProvider::new("sk-ant-test", "claude-3-opus").unwrap();
let messages = vec![Message::system("You are helpful"), Message::user("Hello")];
let options = GenerationOptions::default();
let req = provider.build_request(&messages, &options).unwrap();
assert_eq!(req.model, "claude-3-opus");
assert_eq!(req.messages.len(), 1); // System filtered out
assert!(!req.stream);
}
}

View File

@ -0,0 +1,22 @@
//! LLM provider implementations
#[cfg(feature = "openai")]
pub mod openai;
#[cfg(feature = "anthropic")]
pub mod anthropic;
#[cfg(feature = "ollama")]
pub mod ollama;
#[cfg(feature = "openai")]
#[allow(unused_imports)]
pub use openai::OpenAiProvider;
#[cfg(feature = "anthropic")]
#[allow(unused_imports)]
pub use anthropic::AnthropicProvider;
#[cfg(feature = "ollama")]
#[allow(unused_imports)]
pub use ollama::OllamaProvider;

View File

@ -0,0 +1,266 @@
//! Ollama local LLM provider
use super::super::error::{LlmError, Result};
use super::super::messages::{Message, Role};
use super::super::options::GenerationOptions;
use super::super::{LlmProvider, StreamResponse};
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Ollama API request
#[derive(Debug, Serialize)]
struct OllamaRequest {
model: String,
messages: Vec<OllamaMessage>,
temperature: f32,
stream: bool,
}
/// Ollama message format
#[derive(Debug, Serialize, Deserialize)]
struct OllamaMessage {
role: String,
content: String,
}
/// Ollama API response
#[derive(Debug, Deserialize)]
struct OllamaResponse {
message: OllamaMessage,
}
/// Ollama streaming response
#[derive(Debug, Deserialize)]
struct OllamaStreamResponse {
message: OllamaMessage,
}
/// Ollama provider
pub struct OllamaProvider {
model: String,
client: Arc<Client>,
base_url: String,
}
impl OllamaProvider {
/// Create new Ollama provider
///
/// # Arguments
/// * `model` - Model name (e.g., "llama2", "mistral", "neural-chat")
/// * `base_url` - Ollama server URL (default: http://localhost:11434)
///
/// # Errors
/// Returns error if model is empty
pub fn new(model: impl Into<String>) -> Result<Self> {
Self::with_url(model, "http://localhost:11434")
}
/// Create with custom Ollama server URL
pub fn with_url(model: impl Into<String>, base_url: impl Into<String>) -> Result<Self> {
let model = model.into();
if model.is_empty() {
return Err(LlmError::ConfigError("Ollama model name is empty".into()));
}
Ok(OllamaProvider {
model,
client: Arc::new(Client::new()),
base_url: base_url.into(),
})
}
/// Create from environment variable `OLLAMA_MODEL`
pub fn from_env() -> Result<Self> {
let model = std::env::var("OLLAMA_MODEL")
.map_err(|_| LlmError::ConfigError("OLLAMA_MODEL not set".into()))?;
Self::new(model)
}
fn build_request(&self, messages: &[Message], options: &GenerationOptions) -> OllamaRequest {
let ollama_messages = messages
.iter()
.map(|msg| OllamaMessage {
role: match msg.role {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
}
.to_string(),
content: msg.content.clone(),
})
.collect();
OllamaRequest {
model: self.model.clone(),
messages: ollama_messages,
temperature: options.temperature.clamp(0.0, 2.0),
stream: false,
}
}
fn build_stream_request(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> OllamaRequest {
let mut req = self.build_request(messages, options);
req.stream = true;
req
}
}
#[async_trait]
impl LlmProvider for OllamaProvider {
async fn generate(&self, messages: &[Message], options: &GenerationOptions) -> Result<String> {
let req = self.build_request(messages, options);
let url = format!("{}/api/chat", self.base_url);
let response = self
.client
.post(&url)
.json(&req)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
LlmError::Timeout("Ollama request timed out".into())
} else {
LlmError::NetworkError(e)
}
})?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::ApiError(format!(
"Ollama error ({}): {}",
status, error_text
)));
}
let body: OllamaResponse = response
.json()
.await
.map_err(|e| LlmError::ApiError(format!("Failed to parse response: {}", e)))?;
Ok(body.message.content)
}
async fn stream(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> Result<StreamResponse> {
let req = self.build_stream_request(messages, options);
let url = format!("{}/api/chat", self.base_url);
let response = self
.client
.post(&url)
.json(&req)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
LlmError::Timeout("Ollama request timed out".into())
} else {
LlmError::NetworkError(e)
}
})?;
let status = response.status();
if !status.is_success() {
return Err(LlmError::ApiError(format!("Ollama error: {}", status)));
}
let stream = response
.bytes_stream()
.filter_map(|result| async move {
match result {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
let text = text.trim();
if text.is_empty() {
return None;
}
if let Ok(chunk) = serde_json::from_str::<OllamaStreamResponse>(text) {
if !chunk.message.content.is_empty() {
return Some(Ok(chunk.message.content));
}
}
None
}
Err(e) => Some(Err(LlmError::StreamError(e.to_string()))),
}
})
.boxed();
Ok(Box::pin(stream))
}
fn name(&self) -> &str {
"Ollama"
}
fn model(&self) -> &str {
&self.model
}
async fn is_available(&self) -> bool {
self.client
.get(format!("{}/api/tags", self.base_url))
.send()
.await
.map(|r| r.status().is_success())
.unwrap_or(false)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ollama_new() {
let provider = OllamaProvider::new("llama2");
assert!(provider.is_ok());
}
#[test]
fn test_ollama_empty_model() {
let provider = OllamaProvider::new("");
assert!(provider.is_err());
}
#[test]
fn test_ollama_with_url() {
let provider = OllamaProvider::with_url("llama2", "http://127.0.0.1:11434");
assert!(provider.is_ok());
}
#[test]
fn test_provider_name() {
let provider = OllamaProvider::new("llama2").unwrap();
assert_eq!(provider.name(), "Ollama");
assert_eq!(provider.model(), "llama2");
}
#[test]
fn test_build_request() {
let provider = OllamaProvider::new("llama2").unwrap();
let messages = vec![Message::user("Hello")];
let options = GenerationOptions::default();
let req = provider.build_request(&messages, &options);
assert_eq!(req.model, "llama2");
assert_eq!(req.messages.len(), 1);
assert!(!req.stream);
}
}

View File

@ -0,0 +1,325 @@
//! OpenAI LLM provider (GPT-3.5, GPT-4)
use super::super::error::{LlmError, Result};
use super::super::messages::{Message, Role};
use super::super::options::GenerationOptions;
use super::super::{LlmProvider, StreamResponse};
use async_trait::async_trait;
use futures::stream::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
/// OpenAI API request
#[derive(Debug, Serialize)]
struct OpenAiRequest {
model: String,
messages: Vec<OpenAiMessage>,
temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
frequency_penalty: Option<f32>,
stream: bool,
}
/// OpenAI message format
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiMessage {
role: String,
content: String,
}
/// OpenAI API response
#[derive(Debug, Deserialize)]
struct OpenAiResponse {
choices: Vec<OpenAiChoice>,
}
/// OpenAI choice (completion option)
#[derive(Debug, Deserialize)]
struct OpenAiChoice {
message: OpenAiMessage,
#[allow(dead_code)]
finish_reason: String,
}
/// Streaming response chunk
#[derive(Debug, Deserialize)]
struct OpenAiStreamChunk {
choices: Vec<StreamChoice>,
}
/// Streaming choice
#[derive(Debug, Deserialize)]
struct StreamChoice {
delta: StreamDelta,
}
/// Stream delta content
#[derive(Debug, Deserialize)]
struct StreamDelta {
#[serde(default)]
content: String,
}
/// OpenAI provider
pub struct OpenAiProvider {
api_key: String,
model: String,
client: Arc<Client>,
}
impl OpenAiProvider {
/// Create new OpenAI provider
///
/// # Arguments
/// * `api_key` - OpenAI API key
/// * `model` - Model name (e.g., "gpt-4", "gpt-3.5-turbo")
///
/// # Errors
/// Returns error if api_key is empty
pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
let api_key = api_key.into();
if api_key.is_empty() {
return Err(LlmError::ConfigError("OpenAI API key is empty".into()));
}
Ok(OpenAiProvider {
api_key,
model: model.into(),
client: Arc::new(Client::new()),
})
}
/// Create from environment variable `OPENAI_API_KEY`
pub fn from_env(model: impl Into<String>) -> Result<Self> {
let api_key = std::env::var("OPENAI_API_KEY")
.map_err(|_| LlmError::ConfigError("OPENAI_API_KEY not set".into()))?;
Self::new(api_key, model)
}
fn build_request(&self, messages: &[Message], options: &GenerationOptions) -> OpenAiRequest {
let openai_messages = messages
.iter()
.map(|msg| OpenAiMessage {
role: match msg.role {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
}
.to_string(),
content: msg.content.clone(),
})
.collect();
OpenAiRequest {
model: self.model.clone(),
messages: openai_messages,
temperature: options.temperature.clamp(0.0, 2.0),
max_tokens: options.max_tokens,
top_p: options.top_p,
presence_penalty: options.presence_penalty,
frequency_penalty: options.frequency_penalty,
stream: false,
}
}
fn build_stream_request(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> OpenAiRequest {
let mut req = self.build_request(messages, options);
req.stream = true;
req
}
}
#[async_trait]
impl LlmProvider for OpenAiProvider {
async fn generate(&self, messages: &[Message], options: &GenerationOptions) -> Result<String> {
let req = self.build_request(messages, options);
let response = self
.client
.post(OPENAI_API_URL)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&req)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
LlmError::Timeout("OpenAI request timed out".into())
} else if e.status().is_some_and(|s| s.as_u16() == 429) {
LlmError::RateLimit("OpenAI rate limited".into())
} else {
LlmError::NetworkError(e)
}
})?;
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::ApiError(format!(
"OpenAI API error ({}): {}",
status, error_text
)));
}
let body: OpenAiResponse = response
.json()
.await
.map_err(|e| LlmError::ApiError(format!("Failed to parse response: {}", e)))?;
body.choices
.into_iter()
.next()
.map(|choice| choice.message.content)
.ok_or_else(|| LlmError::ApiError("No choices in response".into()))
}
async fn stream(
&self,
messages: &[Message],
options: &GenerationOptions,
) -> Result<StreamResponse> {
let req = self.build_stream_request(messages, options);
let response = self
.client
.post(OPENAI_API_URL)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&req)
.send()
.await
.map_err(|e| {
if e.is_timeout() {
LlmError::Timeout("OpenAI request timed out".into())
} else {
LlmError::NetworkError(e)
}
})?;
let status = response.status();
if !status.is_success() {
return Err(LlmError::ApiError(format!("OpenAI API error: {}", status)));
}
let stream = response
.bytes_stream()
.filter_map(|result| async move {
match result {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
let text = text.trim();
if text.is_empty() {
return None;
}
for line in text.lines() {
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
return None;
}
if let Ok(chunk) = serde_json::from_str::<OpenAiStreamChunk>(data) {
if let Some(choice) = chunk.choices.first() {
if !choice.delta.content.is_empty() {
return Some(Ok(choice.delta.content.clone()));
}
}
}
}
}
None
}
Err(e) => Some(Err(LlmError::StreamError(e.to_string()))),
}
})
.boxed();
Ok(Box::pin(stream))
}
fn name(&self) -> &str {
"OpenAI"
}
fn model(&self) -> &str {
&self.model
}
async fn is_available(&self) -> bool {
// Quick health check - attempt to verify API key
let req = serde_json::json!({
"model": &self.model,
"messages": [{"role": "user", "content": "test"}],
"max_tokens": 1
});
self.client
.post(OPENAI_API_URL)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&req)
.send()
.await
.map(|r| r.status().is_success())
.unwrap_or(false)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_openai_new() {
let provider = OpenAiProvider::new("sk-test", "gpt-4");
assert!(provider.is_ok());
}
#[test]
fn test_openai_empty_key() {
let provider = OpenAiProvider::new("", "gpt-4");
assert!(provider.is_err());
}
#[test]
fn test_provider_name() {
let provider = OpenAiProvider::new("sk-test", "gpt-4").unwrap();
assert_eq!(provider.name(), "OpenAI");
assert_eq!(provider.model(), "gpt-4");
}
#[test]
fn test_build_request() {
let provider = OpenAiProvider::new("sk-test", "gpt-4").unwrap();
let messages = vec![Message::user("Hello")];
let options = GenerationOptions::default();
let req = provider.build_request(&messages, &options);
assert_eq!(req.model, "gpt-4");
assert_eq!(req.messages.len(), 1);
assert!(!req.stream);
}
#[test]
fn test_build_stream_request() {
let provider = OpenAiProvider::new("sk-test", "gpt-4").unwrap();
let messages = vec![Message::user("Hello")];
let options = GenerationOptions::default();
let req = provider.build_stream_request(&messages, &options);
assert!(req.stream);
}
}

View File

@ -0,0 +1,241 @@
//! Integration with RAG system for LLM context formatting
//!
//! This module provides helpers to format RAG retrieval results as useful
//! context for LLM prompts in the configuration assistant.
/// RAG retrieval result (mirrors typedialog_core::ai::rag::RetrievalResult)
#[derive(Clone)]
pub struct RetrievalResult {
/// Document ID/source identifier
pub doc_id: String,
/// Document content
pub content: String,
/// Combined relevance score (0-1)
pub combined_score: f32,
}
/// Format RAG results as contextual examples for LLM
///
/// Converts a list of RAG retrieval results into a formatted context
/// that can be included in LLM prompts to provide relevant examples.
///
/// # Example
///
/// ```
/// use typedialog_ai::llm::rag_integration::{format_rag_context, RetrievalResult};
///
/// let results = vec![
/// RetrievalResult {
/// doc_id: "example1".to_string(),
/// content: "port = 8080".to_string(),
/// combined_score: 0.92,
/// },
/// ];
///
/// let context = format_rag_context(&results);
/// println!("Context:\n{}", context);
/// ```
pub fn format_rag_context(results: &[RetrievalResult]) -> String {
if results.is_empty() {
return "No relevant examples found.".to_string();
}
let mut context = String::from("Here are relevant examples from similar configurations:\n\n");
for (idx, result) in results.iter().enumerate() {
context.push_str(&format!(
"Example {} (from {}): [Relevance: {:.1}%]\n{}\n\n",
idx + 1,
result.doc_id,
result.combined_score * 100.0,
result.content
));
}
context
}
/// Extract field value suggestions from RAG results
///
/// Analyzes RAG results to extract common values for a specific field name.
///
/// # Example
///
/// ```
/// use typedialog_ai::llm::rag_integration::{extract_field_values, RetrievalResult};
///
/// let results = vec![
/// RetrievalResult {
/// doc_id: "config1".to_string(),
/// content: "port = 8080\nhost = localhost".to_string(),
/// combined_score: 0.90,
/// },
/// ];
///
/// let values = extract_field_values(&results, "port");
/// // Would extract "8080" from the results
/// ```
pub fn extract_field_values(results: &[RetrievalResult], field_name: &str) -> Vec<String> {
let mut values = Vec::new();
for result in results {
// Simple pattern matching for field values
// Looks for patterns like "field_name = value" or "field_name: value"
let lines = result.content.lines();
for line in lines {
let trimmed = line.trim();
// Check for TOML/Nickel style: field = value
if let Some(idx) = trimmed.find('=') {
let key_part = trimmed[..idx].trim();
if key_part == field_name {
let value_part = trimmed[idx + 1..].trim();
let value = clean_value(value_part);
if !value.is_empty() && !values.contains(&value) {
values.push(value);
}
}
}
// Check for JSON/YAML style: field: value or "field": value
if let Some(idx) = trimmed.find(':') {
let key_part = trimmed[..idx].trim().trim_matches('"');
if key_part == field_name {
let value_part = trimmed[idx + 1..].trim();
let value = clean_value(value_part);
if !value.is_empty() && !values.contains(&value) {
values.push(value);
}
}
}
}
}
values
}
/// Clean extracted field values
fn clean_value(value: &str) -> String {
let value = value.trim();
// Remove trailing comma (TOML syntax)
let value = value.strip_suffix(',').unwrap_or(value);
// Remove quotes
let value = value.trim_matches(|c| c == '"' || c == '\'');
// Remove comments
if let Some(idx) = value.find('#') {
value[..idx].trim().to_string()
} else {
value.to_string()
}
}
/// Format RAG results with field-specific recommendations
///
/// Creates a structured recommendation list for a specific field
/// based on RAG retrieval results.
pub fn format_field_recommendations(
field_name: &str,
field_description: &str,
results: &[RetrievalResult],
) -> String {
let values = extract_field_values(results, field_name);
if values.is_empty() {
return format!(
"No direct examples found for '{}', but {} is an important configuration setting.",
field_name, field_description
);
}
let mut rec = format!(
"For the '{}' field ({}):\nCommon values from similar configurations:\n",
field_name, field_description
);
for (idx, value) in values.iter().enumerate() {
rec.push_str(&format!(" {}. {}\n", idx + 1, value));
}
rec
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_format_rag_context_empty() {
let results = vec![];
let context = format_rag_context(&results);
assert_eq!(context, "No relevant examples found.");
}
#[test]
fn test_format_rag_context() {
let results = vec![RetrievalResult {
doc_id: "example1".to_string(),
content: "port = 8080".to_string(),
combined_score: 0.95,
}];
let context = format_rag_context(&results);
assert!(context.contains("Example 1"));
assert!(context.contains("example1"));
assert!(context.contains("port = 8080"));
assert!(context.contains("95.0%"));
}
#[test]
fn test_extract_field_values_toml() {
let results = vec![RetrievalResult {
doc_id: "config.toml".to_string(),
content: "port = 8080\nhost = localhost".to_string(),
combined_score: 0.90,
}];
let values = extract_field_values(&results, "port");
assert!(values.contains(&"8080".to_string()));
let hosts = extract_field_values(&results, "host");
assert!(hosts.contains(&"localhost".to_string()));
}
#[test]
fn test_extract_field_values_json() {
let results = vec![RetrievalResult {
doc_id: "config.json".to_string(),
content: r#" "port": 8080,
"host": "localhost""#
.to_string(),
combined_score: 0.90,
}];
let values = extract_field_values(&results, "host");
assert!(values.contains(&"localhost".to_string()));
}
#[test]
fn test_clean_value() {
assert_eq!(clean_value("8080"), "8080");
assert_eq!(clean_value("\"localhost\""), "localhost");
assert_eq!(clean_value("true,"), "true");
assert_eq!(clean_value("value # comment"), "value");
}
#[test]
fn test_format_field_recommendations() {
let results = vec![RetrievalResult {
doc_id: "example".to_string(),
content: "port = 8080\nport = 3000".to_string(),
combined_score: 0.90,
}];
let rec = format_field_recommendations("port", "Server port number", &results);
assert!(rec.contains("Common values"));
assert!(rec.contains("8080") || rec.contains("3000"));
}
}

View File

@ -0,0 +1,244 @@
#![allow(dead_code)]
//! TypeDialog AI Configuration Assistant Microservice
//!
//! Provides intelligent configuration assistance through:
//! - Conversational dialog with LLM
//! - RAG-powered example suggestions
//! - Nickel schema understanding
//! - Configuration generation and validation
//!
//! # Running the service
//!
//! ```bash
//! # Start the HTTP server
//! typedialog-ai serve --port 3000
//!
//! # Or run in CLI interactive mode
//! typedialog-ai cli --schema schema.ncl
//! ```
mod api;
mod assistant;
mod backend;
mod cli;
mod llm;
mod storage;
mod web_ui;
use anyhow::Result;
use clap::{Parser, Subcommand};
use std::path::{Path, PathBuf};
use std::sync::Arc;
#[derive(Parser)]
#[command(name = "typedialog-ai")]
#[command(about = "AI-powered configuration assistant backend and microservice for TypeDialog", long_about = None)]
struct Cli {
#[command(subcommand)]
command: Option<Commands>,
/// AI backend configuration file (TOML)
///
/// If provided, uses this file exclusively.
/// If not provided, searches: ~/.config/typedialog/ai/{TYPEDIALOG_ENV}.toml → ~/.config/typedialog/ai/config.toml → defaults
#[arg(global = true, short = 'c', long, value_name = "FILE")]
config: Option<PathBuf>,
/// Database path
#[arg(global = true, long, default_value = ".typedialog/ai.db")]
db_path: String,
/// SurrealDB namespace
#[arg(global = true, long, default_value = "default")]
namespace: String,
/// SurrealDB database
#[arg(global = true, long, default_value = "typedialog")]
database: String,
}
#[derive(Subcommand)]
enum Commands {
/// Start the HTTP server
#[command(about = "Start the HTTP API server")]
Serve {
/// Server port
#[arg(short, long, default_value = "3000")]
port: u16,
/// Server host
#[arg(short = 'H', long, default_value = "127.0.0.1")]
host: String,
},
/// Run in CLI interactive mode
#[command(about = "Interactive CLI mode")]
Cli {
/// Nickel schema path
#[arg(short, long)]
schema: String,
/// Conversation ID (resume existing)
#[arg(long)]
conversation_id: Option<String>,
},
}
#[tokio::main]
async fn main() -> Result<()> {
// Initialize tracing
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::from_default_env()
.add_directive("typedialog_ai=debug".parse()?),
)
.init();
let cli = Cli::parse();
// Load configuration
let config_path = cli.config.as_deref();
let config = typedialog_ai::config::TypeDialogAiConfig::load_with_cli(config_path)
.map_err(|e| anyhow::anyhow!("Failed to load configuration: {}", e))?;
if let Some(path) = config_path {
println!("📋 Using config: {}", path.display());
}
match cli.command {
Some(Commands::Serve { port, host }) => {
serve_command(
&host,
port,
&cli.db_path,
&cli.namespace,
&cli.database,
config,
)
.await?;
}
Some(Commands::Cli {
schema,
conversation_id,
}) => {
cli_command(
&schema,
conversation_id,
&cli.db_path,
&cli.namespace,
&cli.database,
config,
)
.await?;
}
None => {
// Default to serve
serve_command(
"127.0.0.1",
3000,
&cli.db_path,
&cli.namespace,
&cli.database,
config,
)
.await?;
}
}
Ok(())
}
async fn serve_command(
host: &str,
port: u16,
db_path: &str,
namespace: &str,
database: &str,
config: typedialog_ai::config::TypeDialogAiConfig,
) -> Result<()> {
println!("🚀 Starting TypeDialog AI Service on {}:{}", host, port);
println!(
"🤖 LLM Provider: {} ({})",
config.llm.provider, config.llm.model
);
println!(
"🔍 RAG System: {}",
if config.rag.enabled {
"enabled"
} else {
"disabled"
}
);
// Initialize database
let client = Arc::new(storage::SurrealDbClient::new(db_path, namespace, database).await?);
client.initialize().await?;
println!("✅ Database initialized");
println!("📡 API listening on http://{}:{}", host, port);
println!("🌐 Web UI at http://{}:{}/ui", host, port);
// Create application state
let app_state = api::AppState {
db: client,
assistants: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
start_time: std::time::Instant::now(),
};
// Create router with all routes
let app = api::create_router(app_state);
// Parse host and create socket address
let addr: std::net::SocketAddr = format!("{}:{}", host, port).parse()?;
// Create listener
let listener = tokio::net::TcpListener::bind(&addr).await?;
tracing::info!("Server listening on {}", addr);
// Run the server
axum::serve(listener, app).await?;
Ok(())
}
async fn cli_command(
schema: &str,
conversation_id: Option<String>,
db_path: &str,
namespace: &str,
database: &str,
config: typedialog_ai::config::TypeDialogAiConfig,
) -> Result<()> {
// Verify schema file exists
let schema_path = Path::new(schema);
if !schema_path.exists() {
anyhow::bail!("Schema file not found: {}", schema);
}
println!(
"🤖 LLM Provider: {} ({})",
config.llm.provider, config.llm.model
);
println!(
"🔍 RAG System: {}",
if config.rag.enabled {
"enabled"
} else {
"disabled"
}
);
// Initialize database
let client = Arc::new(storage::SurrealDbClient::new(db_path, namespace, database).await?);
client.initialize().await?;
// Create interactive session
let mut session = cli::InteractiveSession::new(schema_path, client, conversation_id).await?;
// Run interactive conversation loop
session.run().await?;
Ok(())
}

View File

@ -0,0 +1,388 @@
//! SurrealDB client and connection management
//!
//! Abstracts SurrealDB operations for the AI configuration assistant.
//! Supports both local and remote connections via HTTP.
use crate::storage::models::*;
use anyhow::Result;
use uuid::Uuid;
/// SurrealDB client for AI configuration assistant
///
/// Manages connections to SurrealDB (local or remote) with support for:
/// - Graph relationships between conversations, schemas, fields
/// - Full-text search on messages
/// - Reusable RAG results
/// - Schema validation with SurrealQL contracts
pub struct SurrealDbClient {
/// Connection endpoint (e.g., "http://localhost:8000", "memory://", "file://./db.db")
endpoint: String,
/// SurrealDB namespace
namespace: String,
/// SurrealDB database name
database: String,
}
impl SurrealDbClient {
/// Create new SurrealDB client
///
/// # Arguments
/// * `endpoint` - Connection endpoint
/// - `"memory://"` - In-memory (development, data lost on restart)
/// - `"file://./typedialog.db"` - Local RocksDB file (development)
/// - `"http://localhost:8000"` - Local HTTP server
/// - `"https://surreal.example.com"` - Remote HTTP endpoint (production)
/// * `namespace` - SurrealDB namespace (e.g., "default")
/// * `database` - SurrealDB database name (e.g., "typedialog")
///
/// # Errors
/// Returns error if endpoint format is invalid
pub async fn new(endpoint: &str, namespace: &str, database: &str) -> Result<Self> {
// Validate endpoint
let valid = endpoint == "memory://"
|| endpoint.starts_with("file://")
|| endpoint.starts_with("http");
if !valid {
return Err(anyhow::anyhow!(
"Invalid endpoint. Use 'memory://', 'file://<path>', or 'http(s)://<host>:<port>'"
));
}
tracing::debug!(
endpoint = %endpoint,
namespace = %namespace,
database = %database,
"Creating SurrealDB client"
);
Ok(SurrealDbClient {
endpoint: endpoint.to_string(),
namespace: namespace.to_string(),
database: database.to_string(),
})
}
/// Get connection endpoint
pub fn endpoint(&self) -> &str {
&self.endpoint
}
/// Get namespace
pub fn namespace(&self) -> &str {
&self.namespace
}
/// Get database name
pub fn database(&self) -> &str {
&self.database
}
/// Initialize database schema
///
/// Creates necessary tables, indexes, and relationships:
/// - `conversations` - Main conversation records with status tracking
/// - `messages` - Individual messages with optional RAG result links
/// - `rag_results` - Reusable RAG retrieval results
/// - `schemas` - Nickel schema definitions
/// - `fields` - Extracted schema field metadata
///
/// All tables include proper indexing:
/// - Foreign key relationships for graph queries
/// - Full-text search indexes on messages
/// - Status and timestamp indexes for efficient filtering
pub async fn initialize(&self) -> Result<()> {
tracing::info!(
endpoint = %self.endpoint,
"Initializing SurrealDB schema (ns='{}', db='{}')",
self.namespace,
self.database
);
// TODO: Execute SurrealQL schema initialization via HTTP client
// This would create:
// - DEFINE TABLE conversations SCHEMAFULL;
// - DEFINE TABLE messages SCHEMAFULL;
// - DEFINE TABLE rag_results SCHEMAFULL;
// - DEFINE TABLE schemas SCHEMAFULL;
// - DEFINE TABLE fields SCHEMAFULL;
// - DEFINE INDEX indexes for efficient queries
Ok(())
}
// ============================================================================
// CONVERSATION OPERATIONS
// ============================================================================
/// Create new conversation
///
/// Initializes a new conversation for configuring a Nickel schema
pub async fn create_conversation(&self, schema_id: &str) -> Result<String> {
let id = Uuid::new_v4().to_string();
tracing::debug!(conversation_id = %id, schema_id = %schema_id, "Creating conversation");
// TODO: INSERT INTO conversations (id, schema_id, created_at, updated_at, status, collected_fields)
// VALUES (:id, :schema_id, now(), now(), 'active', {})
Ok(id)
}
/// Get conversation by ID
pub async fn get_conversation(&self, _id: &str) -> Result<Option<Conversation>> {
// TODO: SELECT * FROM conversations WHERE id = :id
Ok(None)
}
/// List conversations with filters
pub async fn list_conversations(
&self,
_filter: &ConversationFilter,
) -> Result<Vec<Conversation>> {
// TODO: Build query with filters and execute
Ok(vec![])
}
/// Update conversation status
pub async fn update_conversation_status(
&self,
_id: &str,
_status: ConversationStatus,
) -> Result<()> {
// TODO: UPDATE conversations SET status = :status, updated_at = now() WHERE id = :id
Ok(())
}
/// Update collected fields
pub async fn update_collected_fields(
&self,
_id: &str,
_fields: serde_json::Value,
) -> Result<()> {
// TODO: UPDATE conversations SET collected_fields = :fields, updated_at = now() WHERE id = :id
Ok(())
}
// ============================================================================
// MESSAGE OPERATIONS
// ============================================================================
/// Create new message in conversation
pub async fn create_message(
&self,
conversation_id: &str,
role: MessageRole,
_content: &str,
) -> Result<String> {
let id = Uuid::new_v4().to_string();
tracing::debug!(message_id = %id, conversation_id = %conversation_id, role = ?role, "Creating message");
// TODO: INSERT INTO messages (id, conversation_id, role, content, timestamp, rag_result_ids)
// VALUES (:id, :conversation_id, :role, :content, now(), [])
Ok(id)
}
/// Get message by ID
pub async fn get_message(&self, _id: &str) -> Result<Option<Message>> {
// TODO: SELECT * FROM messages WHERE id = :id
Ok(None)
}
/// List messages with filters
pub async fn list_messages(&self, _filter: &MessageFilter) -> Result<Vec<Message>> {
// TODO: Build query with filters, order by timestamp, apply limit/offset
Ok(vec![])
}
/// Search messages with full-text search
///
/// Uses SurrealDB full-text search to find messages containing specific terms
pub async fn search_messages(
&self,
conversation_id: &str,
query: &str,
) -> Result<Vec<Message>> {
tracing::debug!(conversation_id = %conversation_id, query = %query, "Searching messages");
// TODO: SELECT * FROM messages WHERE conversation_id = :conversation_id AND content CONTAINS :query
Ok(vec![])
}
/// Link RAG results to message
///
/// Associates RAG retrieval results with a message for relationship queries
pub async fn link_rag_results(
&self,
_message_id: &str,
_rag_result_ids: Vec<String>,
) -> Result<()> {
// TODO: UPDATE messages SET rag_result_ids = :rag_result_ids WHERE id = :message_id
Ok(())
}
// ============================================================================
// RAG RESULT OPERATIONS
// ============================================================================
/// Create RAG retrieval result
///
/// Stores reusable RAG results that can be linked to multiple messages
pub async fn create_rag_result(
&self,
_doc_id: &str,
_content: &str,
_combined_score: f32,
_semantic_score: f32,
_keyword_score: f32,
) -> Result<String> {
let id = Uuid::new_v4().to_string();
// TODO: INSERT INTO rag_results (...) VALUES (...)
Ok(id)
}
/// Get RAG result by ID
pub async fn get_rag_result(&self, _id: &str) -> Result<Option<RagResult>> {
// TODO: SELECT * FROM rag_results WHERE id = :id
Ok(None)
}
// ============================================================================
// SCHEMA OPERATIONS
// ============================================================================
/// Create schema record
///
/// Stores Nickel schema for reuse across conversations
pub async fn create_schema(&self, name: &str, _path: &str, _content: &str) -> Result<String> {
let id = Uuid::new_v4().to_string();
tracing::debug!(schema_id = %id, schema_name = %name, "Creating schema");
// TODO: INSERT INTO schemas (id, name, path, content, created_at, updated_at) VALUES (...)
Ok(id)
}
/// Get schema by ID
pub async fn get_schema(&self, _id: &str) -> Result<Option<Schema>> {
// TODO: SELECT * FROM schemas WHERE id = :id
Ok(None)
}
/// Get schema by name
pub async fn get_schema_by_name(&self, _name: &str) -> Result<Option<Schema>> {
// TODO: SELECT * FROM schemas WHERE name = :name LIMIT 1
Ok(None)
}
/// Get all schemas
pub async fn list_schemas(&self) -> Result<Vec<Schema>> {
// TODO: SELECT * FROM schemas ORDER BY created_at DESC
Ok(vec![])
}
// ============================================================================
// FIELD OPERATIONS
// ============================================================================
/// Create field definition
///
/// Stores extracted field metadata from Nickel schema
pub async fn create_field(
&self,
_schema_id: &str,
_name: &str,
_field_type: &str,
_description: &str,
_required: bool,
) -> Result<String> {
let id = Uuid::new_v4().to_string();
// TODO: INSERT INTO fields (id, schema_id, name, field_type, description, required, created_at)
Ok(id)
}
/// Get fields for schema
pub async fn get_schema_fields(&self, _schema_id: &str) -> Result<Vec<Field>> {
// TODO: SELECT * FROM fields WHERE schema_id = :schema_id
Ok(vec![])
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_client_creation() {
let client = SurrealDbClient::new("memory://", "default", "typedialog")
.await
.unwrap();
assert_eq!(client.endpoint(), "memory://");
assert_eq!(client.namespace(), "default");
assert_eq!(client.database(), "typedialog");
}
#[tokio::test]
async fn test_invalid_endpoint() {
let result = SurrealDbClient::new("invalid://endpoint", "default", "typedialog").await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_file_endpoint() {
let client = SurrealDbClient::new("file://./test.db", "default", "typedialog")
.await
.unwrap();
assert_eq!(client.endpoint(), "file://./test.db");
}
#[tokio::test]
async fn test_http_endpoint() {
let client = SurrealDbClient::new("http://localhost:8000", "default", "typedialog")
.await
.unwrap();
assert_eq!(client.endpoint(), "http://localhost:8000");
}
#[tokio::test]
async fn test_initialize() {
let client = SurrealDbClient::new("memory://", "default", "typedialog")
.await
.unwrap();
let result = client.initialize().await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_create_conversation() {
let client = SurrealDbClient::new("memory://", "default", "typedialog")
.await
.unwrap();
let result = client.create_conversation("schema1").await;
assert!(result.is_ok());
let id = result.unwrap();
// ID should be a UUID
assert!(!id.is_empty());
assert_eq!(id.len(), 36); // UUID format: 8-4-4-4-12
}
#[tokio::test]
async fn test_create_message() {
let client = SurrealDbClient::new("memory://", "default", "typedialog")
.await
.unwrap();
let result = client
.create_message("conv1", MessageRole::User, "Hello")
.await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_create_schema() {
let client = SurrealDbClient::new("memory://", "default", "typedialog")
.await
.unwrap();
let result = client.create_schema("config", "schema.ncl", "{}").await;
assert!(result.is_ok());
}
}

View File

@ -0,0 +1,18 @@
//! SurrealDB storage layer for TypeDialog AI service
//!
//! Provides persistent storage with:
//! - Separate tables for conversations, messages, RAG results, schemas, fields
//! - Graph relationships between entities
//! - Full-text search on messages
//! - Schema validation with SurrealQL contracts
//! - Reusable messages and RAG results
pub mod client;
pub mod models;
pub use client::SurrealDbClient;
#[allow(unused_imports)]
pub use models::{
Conversation, ConversationFilter, ConversationStatus, Field, Message, MessageFilter,
MessageRole, RagResult, Schema,
};

View File

@ -0,0 +1,254 @@
//! Data models for SurrealDB storage
//!
//! Defines the schema for all SurrealDB tables with support for:
//! - Graph relationships between conversations, schemas, fields
//! - Full-text search on messages
//! - Reusable messages and RAG results
//! - Schema validation with SurrealQL contracts
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// Conversation record in SurrealDB
///
/// Represents a user conversation session with a particular schema
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Conversation {
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// Schema being configured
pub schema_id: String,
/// Creation timestamp
pub created_at: DateTime<Utc>,
/// Last update timestamp
pub updated_at: DateTime<Utc>,
/// Conversation status
pub status: ConversationStatus,
/// Collected field values as JSON
#[serde(default)]
pub collected_fields: serde_json::Value,
}
/// Conversation status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ConversationStatus {
#[serde(rename = "active")]
Active,
#[serde(rename = "completed")]
Completed,
#[serde(rename = "abandoned")]
Abandoned,
}
/// Message record in SurrealDB
///
/// Individual message in a conversation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// Conversation this message belongs to
pub conversation_id: String,
/// Message role
pub role: MessageRole,
/// Message content
pub content: String,
/// Message timestamp
pub timestamp: DateTime<Utc>,
/// IDs of RAG results used for this message
#[serde(default)]
pub rag_result_ids: Vec<String>,
/// Full-text search index
#[serde(skip)]
pub search_text: String,
}
/// Message role in conversation
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageRole {
#[serde(rename = "user")]
User,
#[serde(rename = "assistant")]
Assistant,
}
/// RAG retrieval result
///
/// Reusable RAG results that can be linked to multiple messages
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagResult {
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// Source document ID
pub doc_id: String,
/// Retrieved content
pub content: String,
/// Relevance score (0-1)
pub combined_score: f32,
/// Semantic score component
pub semantic_score: f32,
/// Keyword score component
pub keyword_score: f32,
/// Creation timestamp
pub created_at: DateTime<Utc>,
}
/// Schema record
///
/// Stores Nickel schemas for reuse across conversations
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Schema {
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// Schema name
pub name: String,
/// Schema file path
pub path: String,
/// Full Nickel schema content
pub content: String,
/// Creation timestamp
pub created_at: DateTime<Utc>,
/// Last update timestamp
pub updated_at: DateTime<Utc>,
}
/// Schema field definition
///
/// Individual field extracted from a Nickel schema
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Field {
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// Schema this field belongs to
pub schema_id: String,
/// Field name
pub name: String,
/// Field type (String, Number, Bool, etc.)
pub field_type: String,
/// Human-readable description
pub description: String,
/// Is field required?
pub required: bool,
/// Default value if any
pub default_value: Option<String>,
/// Creation timestamp
pub created_at: DateTime<Utc>,
}
/// Query filter for messages
pub struct MessageFilter {
pub conversation_id: Option<String>,
pub role: Option<MessageRole>,
pub limit: Option<u32>,
pub offset: Option<u32>,
}
impl Default for MessageFilter {
fn default() -> Self {
MessageFilter {
conversation_id: None,
role: None,
limit: Some(50),
offset: Some(0),
}
}
}
/// Query filter for conversations
pub struct ConversationFilter {
pub schema_id: Option<String>,
pub status: Option<ConversationStatus>,
pub limit: Option<u32>,
pub offset: Option<u32>,
}
impl Default for ConversationFilter {
fn default() -> Self {
ConversationFilter {
schema_id: None,
status: None,
limit: Some(50),
offset: Some(0),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use uuid::Uuid;
#[test]
fn test_conversation_creation() {
let conv = Conversation {
id: Some(Uuid::new_v4().to_string()),
schema_id: "schema1".to_string(),
created_at: Utc::now(),
updated_at: Utc::now(),
status: ConversationStatus::Active,
collected_fields: serde_json::json!({}),
};
assert_eq!(conv.status, ConversationStatus::Active);
}
#[test]
fn test_message_serialization() {
let msg = Message {
id: Some(Uuid::new_v4().to_string()),
conversation_id: "conv1".to_string(),
role: MessageRole::User,
content: "Hello".to_string(),
timestamp: Utc::now(),
rag_result_ids: vec![],
search_text: "hello".to_string(),
};
let json = serde_json::to_string(&msg).unwrap();
drop(serde_json::from_str::<Message>(&json).unwrap());
}
#[test]
fn test_rag_result_scores() {
let rag = RagResult {
id: Some(Uuid::new_v4().to_string()),
doc_id: "doc1".to_string(),
content: "content".to_string(),
combined_score: 0.85,
semantic_score: 0.90,
keyword_score: 0.80,
created_at: Utc::now(),
};
assert!(rag.combined_score >= 0.0 && rag.combined_score <= 1.0);
}
}

View File

@ -0,0 +1,24 @@
//! Web UI module for serving the interactive web interface
use axum::response::{Html, IntoResponse};
/// Serve the main HTML page
pub async fn index() -> Html<&'static str> {
Html(include_str!("static/index.html"))
}
/// Serve the CSS styles
pub async fn styles() -> impl IntoResponse {
(
[("Content-Type", "text/css")],
include_str!("static/styles.css"),
)
}
/// Serve the JavaScript application
pub async fn app() -> impl IntoResponse {
(
[("Content-Type", "application/javascript")],
include_str!("static/app.js"),
)
}

View File

@ -0,0 +1,587 @@
/**
* TypeDialog AI Configuration Assistant - Web UI
* Handles interactive configuration generation with real-time streaming
*/
// ============================================================================
// Global State
// ============================================================================
const state = {
conversationId: null,
selectedSchema: null,
ws: null,
isConnected: false,
isStreaming: false,
generatedConfig: null,
messageHistory: [],
suggestions: [],
};
// ============================================================================
// DOM References
// ============================================================================
const elements = {
// Sidebar
schemaSelect: document.getElementById('schemaSelect'),
startBtn: document.getElementById('startBtn'),
copyBtn: document.getElementById('copyBtn'),
conversationId: document.getElementById('conversationId'),
requiredFields: document.getElementById('requiredFields'),
suggestBtn: document.getElementById('suggestBtn'),
generateBtn: document.getElementById('generateBtn'),
clearBtn: document.getElementById('clearBtn'),
// Chat
chatContainer: document.getElementById('chatContainer'),
messageInput: document.getElementById('messageInput'),
sendBtn: document.getElementById('sendBtn'),
// Suggestions
suggestionsSection: document.getElementById('suggestionsSection'),
suggestionsContainer: document.getElementById('suggestionsContainer'),
// Config
configSection: document.getElementById('configSection'),
formatSelect: document.getElementById('formatSelect'),
configPreview: document.getElementById('configPreview'),
downloadBtn: document.getElementById('downloadBtn'),
copyConfigBtn: document.getElementById('copyConfigBtn'),
// Status
statusText: document.getElementById('statusText'),
connectionStatus: document.getElementById('connectionStatus'),
loadingOverlay: document.getElementById('loadingOverlay'),
};
// ============================================================================
// Initialization
// ============================================================================
document.addEventListener('DOMContentLoaded', () => {
setupEventListeners();
updateUI();
});
function setupEventListeners() {
elements.schemaSelect.addEventListener('change', onSchemaChange);
elements.startBtn.addEventListener('click', onStartConversation);
elements.copyBtn.addEventListener('click', onCopyConversationId);
elements.sendBtn.addEventListener('click', onSendMessage);
elements.messageInput.addEventListener('keypress', (e) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
onSendMessage();
}
});
elements.suggestBtn.addEventListener('click', onGetSuggestions);
elements.generateBtn.addEventListener('click', onGenerateConfig);
elements.downloadBtn.addEventListener('click', onDownloadConfig);
elements.copyConfigBtn.addEventListener('click', onCopyConfig);
elements.formatSelect.addEventListener('change', onFormatChange);
elements.clearBtn.addEventListener('click', onClearConversation);
}
// ============================================================================
// Event Handlers
// ============================================================================
function onSchemaChange() {
state.selectedSchema = elements.schemaSelect.value;
updateUI();
if (state.selectedSchema) {
updateRequiredFields(state.selectedSchema);
}
}
function onStartConversation() {
if (!state.selectedSchema) {
showStatus('Please select a schema', 'error');
return;
}
// Generate a conversation ID
state.conversationId = generateConversationId();
elements.conversationId.textContent = state.conversationId;
// Clear chat
state.messageHistory = [];
renderChat();
// Connect WebSocket
connectWebSocket();
updateUI();
showStatus(`Conversation started: ${state.conversationId}`);
}
function onSendMessage() {
const message = elements.messageInput.value.trim();
if (!message) return;
if (!state.ws || !state.isConnected) {
showStatus('Not connected', 'error');
return;
}
// Add user message to history
state.messageHistory.push({
role: 'user',
content: message,
timestamp: new Date(),
});
// Clear input and render
elements.messageInput.value = '';
renderChat();
// Send via WebSocket
const wsMessage = {
type: 'message',
content: message,
};
state.ws.send(JSON.stringify(wsMessage));
showStatus('Sending message...');
}
function onGetSuggestions() {
if (!state.conversationId) {
showStatus('Start a conversation first', 'error');
return;
}
showLoading(true);
showStatus('Fetching suggestions...');
// Simulate suggestions (in production, would fetch from API)
setTimeout(() => {
state.suggestions = [
{
field: 'port',
value: '8080',
reasoning: 'Standard web server port',
confidence: 0.92,
},
{
field: 'host',
value: 'localhost',
reasoning: 'Development default',
confidence: 0.85,
},
{
field: 'timeout',
value: '30',
reasoning: 'Balanced responsiveness',
confidence: 0.78,
},
];
renderSuggestions();
showLoading(false);
showStatus('Suggestions loaded');
}, 500);
}
function onGenerateConfig() {
if (!state.conversationId) {
showStatus('Start a conversation first', 'error');
return;
}
showLoading(true);
showStatus('Generating configuration...');
// Simulate config generation (in production, would call API)
setTimeout(() => {
const format = elements.formatSelect.value;
state.generatedConfig = generateConfigPreview(format);
renderConfigPreview();
showLoading(false);
showStatus('Configuration generated');
}, 1000);
}
function onDownloadConfig() {
if (!state.generatedConfig) {
showStatus('Generate a configuration first', 'error');
return;
}
const format = elements.formatSelect.value;
const filename = `config.${format}`;
const blob = new Blob([state.generatedConfig], { type: 'text/plain' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = filename;
a.click();
URL.revokeObjectURL(url);
showStatus('Configuration downloaded');
}
function onCopyConfig() {
if (!state.generatedConfig) {
showStatus('Generate a configuration first', 'error');
return;
}
navigator.clipboard.writeText(state.generatedConfig).then(() => {
showStatus('Configuration copied to clipboard');
});
}
function onCopyConversationId() {
if (!state.conversationId) return;
navigator.clipboard.writeText(state.conversationId).then(() => {
showStatus('Conversation ID copied');
});
}
function onFormatChange() {
if (state.generatedConfig) {
const format = elements.formatSelect.value;
state.generatedConfig = generateConfigPreview(format);
renderConfigPreview();
}
}
function onClearConversation() {
if (!confirm('Clear conversation history?')) return;
state.conversationId = null;
state.messageHistory = [];
state.generatedConfig = null;
state.suggestions = [];
renderChat();
updateUI();
showStatus('Conversation cleared');
}
// ============================================================================
// WebSocket Management
// ============================================================================
function connectWebSocket() {
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const wsUrl = `${protocol}//${window.location.host}/ws/${state.conversationId}`;
try {
state.ws = new WebSocket(wsUrl);
state.ws.onopen = () => {
state.isConnected = true;
updateUI();
showStatus('Connected');
};
state.ws.onmessage = (event) => {
try {
const response = JSON.parse(event.data);
handleWebSocketMessage(response);
} catch (e) {
console.error('Failed to parse WebSocket message:', e);
}
};
state.ws.onerror = (error) => {
console.error('WebSocket error:', error);
showStatus('Connection error', 'error');
};
state.ws.onclose = () => {
state.isConnected = false;
updateUI();
showStatus('Disconnected', 'error');
};
} catch (e) {
showStatus('Failed to connect', 'error');
console.error(e);
}
}
function handleWebSocketMessage(response) {
switch (response.type) {
case 'start':
state.isStreaming = true;
state.messageHistory.push({
role: 'assistant',
content: '',
timestamp: new Date(),
isStreaming: true,
});
renderChat();
break;
case 'chunk':
if (state.messageHistory.length > 0) {
const lastMsg = state.messageHistory[state.messageHistory.length - 1];
if (lastMsg.role === 'assistant') {
lastMsg.content += response.content;
}
}
renderChat();
break;
case 'end':
state.isStreaming = false;
if (state.messageHistory.length > 0) {
const lastMsg = state.messageHistory[state.messageHistory.length - 1];
lastMsg.isStreaming = false;
}
renderChat();
showStatus('Message received');
break;
case 'error':
showStatus(`Error: ${response.content}`, 'error');
break;
case 'suggestions':
try {
state.suggestions = JSON.parse(response.content);
renderSuggestions();
} catch (e) {
console.error('Failed to parse suggestions:', e);
}
break;
default:
console.warn('Unknown message type:', response.type);
}
}
// ============================================================================
// Rendering Functions
// ============================================================================
function renderChat() {
if (state.messageHistory.length === 0) {
elements.chatContainer.innerHTML = `
<div class="welcome-message">
<p>👋 Welcome to TypeDialog AI Configuration Assistant</p>
<p>Type a configuration question to begin</p>
</div>
`;
return;
}
const messageHtml = state.messageHistory
.map((msg) => {
const timeStr = msg.timestamp.toLocaleTimeString();
if (msg.isStreaming) {
return `
<div class="message ${msg.role}">
<div class="message-bubble">
<div class="streaming">
<span></span>
<span></span>
<span></span>
</div>
</div>
</div>
`;
}
return `
<div class="message ${msg.role}">
<div class="message-bubble">${escapeHtml(msg.content)}</div>
<div class="message-time">${timeStr}</div>
</div>
`;
})
.join('');
elements.chatContainer.innerHTML = messageHtml;
elements.chatContainer.scrollTop = elements.chatContainer.scrollHeight;
}
function renderSuggestions() {
if (state.suggestions.length === 0) {
elements.suggestionsSection.style.display = 'none';
return;
}
elements.suggestionsSection.style.display = 'block';
const suggestionsHtml = state.suggestions
.map((suggestion) => {
const confidencePercent = (suggestion.confidence * 100).toFixed(0);
return `
<div class="suggestion-card" onclick="insertSuggestion('${escapeAttr(suggestion.field)}', '${escapeAttr(suggestion.value)}')">
<div class="suggestion-field">${escapeHtml(suggestion.field)}</div>
<div class="suggestion-value">${escapeHtml(suggestion.value)}</div>
<div class="suggestion-reasoning">${escapeHtml(suggestion.reasoning)}</div>
<div class="suggestion-confidence">
<span>${confidencePercent}%</span>
<div class="confidence-bar">
<div class="confidence-fill" style="width: ${confidencePercent}%"></div>
</div>
</div>
</div>
`;
})
.join('');
elements.suggestionsContainer.innerHTML = suggestionsHtml;
}
function renderConfigPreview() {
elements.configSection.style.display = 'block';
const codeElement = elements.configPreview.querySelector('code');
if (codeElement) {
codeElement.textContent = state.generatedConfig;
}
}
function updateUI() {
const hasConversation = state.conversationId !== null;
const hasSchema = state.selectedSchema !== null;
// Sidebar
elements.startBtn.disabled = !hasSchema;
elements.copyBtn.disabled = !hasConversation;
elements.conversationId.textContent = state.conversationId || '-';
// Chat
elements.messageInput.disabled = !state.isConnected;
elements.sendBtn.disabled = !state.isConnected || state.isStreaming;
// Actions
elements.suggestBtn.disabled = !hasConversation;
elements.generateBtn.disabled = !hasConversation;
elements.clearBtn.disabled = !hasConversation;
elements.downloadBtn.disabled = !state.generatedConfig;
elements.copyConfigBtn.disabled = !state.generatedConfig;
// Status
updateConnectionStatus();
}
function updateConnectionStatus() {
if (state.isConnected) {
elements.connectionStatus.textContent = '● Connected';
elements.connectionStatus.classList.remove('disconnected');
elements.connectionStatus.classList.add('connected');
} else {
elements.connectionStatus.textContent = '● Disconnected';
elements.connectionStatus.classList.remove('connected');
elements.connectionStatus.classList.add('disconnected');
}
}
function updateRequiredFields(schema) {
const fieldsMap = {
ServerConfig: ['port', 'host', 'timeout'],
DatabaseConfig: ['host', 'port', 'database', 'username'],
ApplicationConfig: ['name', 'version', 'debug_mode', 'log_level'],
};
const fields = fieldsMap[schema] || [];
if (fields.length === 0) {
elements.requiredFields.innerHTML =
'<li class="placeholder">No fields for this schema</li>';
return;
}
const fieldsHtml = fields
.map((field) => `<li class="required">${escapeHtml(field)}</li>`)
.join('');
elements.requiredFields.innerHTML = fieldsHtml;
}
// ============================================================================
// Utility Functions
// ============================================================================
function insertSuggestion(field, value) {
elements.messageInput.value = `Set ${field} to ${value}`;
elements.messageInput.focus();
}
function generateConversationId() {
return `conv-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
}
function showStatus(message, type = 'info') {
elements.statusText.textContent = message;
elements.statusText.style.color =
type === 'error' ? 'var(--color-danger)' : 'var(--color-text-light)';
setTimeout(() => {
elements.statusText.textContent = 'Ready';
elements.statusText.style.color = 'var(--color-text-light)';
}, 3000);
}
function showLoading(show) {
if (show) {
elements.loadingOverlay.style.display = 'flex';
} else {
elements.loadingOverlay.style.display = 'none';
}
}
function generateConfigPreview(format) {
const configs = {
json: JSON.stringify(
{
server: {
port: 8080,
host: 'localhost',
timeout: 30,
},
database: {
host: 'localhost',
port: 5432,
name: 'myapp',
},
logging: {
level: 'info',
format: 'json',
},
},
null,
2
),
yaml: `server:
port: 8080
host: localhost
timeout: 30
database:
host: localhost
port: 5432
name: myapp
logging:
level: info
format: json`,
toml: `[server]
port = 8080
host = "localhost"
timeout = 30
[database]
host = "localhost"
port = 5432
name = "myapp"
[logging]
level = "info"
format = "json"`,
};
return configs[format] || configs.json;
}
function escapeHtml(text) {
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
}
function escapeAttr(text) {
return text.replace(/"/g, '&quot;').replace(/'/g, '&#39;');
}
// ============================================================================
// Export for testing
// ============================================================================
if (typeof module !== 'undefined' && module.exports) {
module.exports = { state, generateConversationId, escapeHtml };
}

View File

@ -0,0 +1,121 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>TypeDialog AI Configuration Assistant</title>
<link rel="stylesheet" href="styles.css">
</head>
<body>
<div class="container">
<!-- Header -->
<header class="header">
<div class="header-content">
<h1>🤖 TypeDialog AI Configuration Assistant</h1>
<p class="subtitle">Intelligent configuration generation with LLM assistance</p>
</div>
</header>
<!-- Main Layout -->
<div class="main-layout">
<!-- Sidebar -->
<aside class="sidebar">
<div class="sidebar-section">
<h3>Schema Selection</h3>
<select id="schemaSelect" class="schema-select">
<option value="">Select a schema...</option>
<option value="ServerConfig">Server Configuration</option>
<option value="DatabaseConfig">Database Configuration</option>
<option value="ApplicationConfig">Application Configuration</option>
</select>
<button id="startBtn" class="btn btn-primary" disabled>Start Conversation</button>
</div>
<div class="sidebar-section">
<h3>Conversation ID</h3>
<div id="conversationId" class="conversation-id">-</div>
<button id="copyBtn" class="btn btn-secondary" disabled>Copy ID</button>
</div>
<div class="sidebar-section">
<h3>Required Fields</h3>
<ul id="requiredFields" class="fields-list">
<li class="placeholder">Select a schema to see required fields</li>
</ul>
</div>
<div class="sidebar-section">
<h3>Actions</h3>
<button id="suggestBtn" class="btn btn-secondary" disabled>Get Suggestions</button>
<button id="generateBtn" class="btn btn-secondary" disabled>Generate Config</button>
<button id="clearBtn" class="btn btn-danger" disabled>Clear Conversation</button>
</div>
</aside>
<!-- Main Content -->
<main class="main-content">
<!-- Chat Section -->
<section class="chat-section">
<h2>Conversation</h2>
<div id="chatContainer" class="chat-container">
<div class="welcome-message">
<p>👋 Welcome to TypeDialog AI Configuration Assistant</p>
<p>Select a schema on the left to begin</p>
</div>
</div>
<!-- Input Area -->
<div class="input-area">
<textarea
id="messageInput"
class="message-input"
placeholder="Ask a configuration question..."
disabled
></textarea>
<button id="sendBtn" class="btn btn-primary" disabled>Send Message</button>
</div>
</section>
<!-- Suggestions Section -->
<section class="suggestions-section" id="suggestionsSection" style="display: none;">
<h3>Field Suggestions</h3>
<div id="suggestionsContainer" class="suggestions-container">
<!-- Suggestions will be populated here -->
</div>
</section>
<!-- Configuration Preview Section -->
<section class="config-section" id="configSection" style="display: none;">
<div class="config-header">
<h3>Configuration Preview</h3>
<select id="formatSelect" class="format-select">
<option value="json">JSON</option>
<option value="yaml">YAML</option>
<option value="toml">TOML</option>
</select>
</div>
<pre id="configPreview" class="config-preview"><code>// No configuration generated yet</code></pre>
<div class="config-actions">
<button id="downloadBtn" class="btn btn-secondary" disabled>Download Config</button>
<button id="copyConfigBtn" class="btn btn-secondary" disabled>Copy to Clipboard</button>
</div>
</section>
</main>
</div>
<!-- Status Bar -->
<footer class="status-bar">
<span id="statusText" class="status-text">Ready</span>
<span id="connectionStatus" class="connection-status connected">● Connected</span>
</footer>
</div>
<!-- Loading Indicator -->
<div id="loadingOverlay" class="loading-overlay" style="display: none;">
<div class="spinner"></div>
<p>Processing...</p>
</div>
<script src="app.js"></script>
</body>
</html>

View File

@ -0,0 +1,672 @@
/* ============================================================================
Global Styles
============================================================================ */
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
:root {
--color-primary: #4f46e5;
--color-primary-dark: #4338ca;
--color-secondary: #8b5cf6;
--color-success: #10b981;
--color-danger: #ef4444;
--color-warning: #f59e0b;
--color-bg: #ffffff;
--color-bg-alt: #f9fafb;
--color-bg-hover: #f3f4f6;
--color-border: #e5e7eb;
--color-text: #1f2937;
--color-text-light: #6b7280;
--color-text-lighter: #9ca3af;
--font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
--font-mono: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Courier New', monospace;
--shadow-sm: 0 1px 2px 0 rgba(0, 0, 0, 0.05);
--shadow-md: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
--shadow-lg: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
--border-radius: 8px;
--transition: all 0.3s ease;
}
body {
font-family: var(--font-family);
background: var(--color-bg-alt);
color: var(--color-text);
line-height: 1.6;
}
/* ============================================================================
Layout
============================================================================ */
.container {
height: 100vh;
display: flex;
flex-direction: column;
}
.header {
background: linear-gradient(135deg, var(--color-primary) 0%, var(--color-secondary) 100%);
color: white;
padding: 1.5rem;
box-shadow: var(--shadow-md);
z-index: 100;
}
.header-content h1 {
font-size: 1.875rem;
font-weight: 700;
margin-bottom: 0.25rem;
}
.subtitle {
font-size: 0.875rem;
opacity: 0.9;
}
.main-layout {
display: flex;
flex: 1;
overflow: hidden;
}
.sidebar {
width: 300px;
background: var(--color-bg);
border-right: 1px solid var(--color-border);
overflow-y: auto;
padding: 1.5rem;
gap: 1.5rem;
display: flex;
flex-direction: column;
}
.main-content {
flex: 1;
display: flex;
flex-direction: column;
overflow: hidden;
padding: 1.5rem;
gap: 1rem;
}
.status-bar {
background: var(--color-bg);
border-top: 1px solid var(--color-border);
padding: 0.75rem 1.5rem;
display: flex;
justify-content: space-between;
align-items: center;
font-size: 0.875rem;
color: var(--color-text-light);
}
/* ============================================================================
Sidebar Components
============================================================================ */
.sidebar-section {
display: flex;
flex-direction: column;
gap: 0.75rem;
}
.sidebar-section h3 {
font-size: 0.875rem;
font-weight: 600;
color: var(--color-text-light);
text-transform: uppercase;
letter-spacing: 0.05em;
}
.schema-select,
.format-select {
padding: 0.625rem;
border: 1px solid var(--color-border);
border-radius: var(--border-radius);
font-family: var(--font-family);
font-size: 0.875rem;
background: var(--color-bg);
color: var(--color-text);
cursor: pointer;
transition: var(--transition);
}
.schema-select:hover,
.format-select:hover {
border-color: var(--color-primary);
}
.schema-select:focus,
.format-select:focus {
outline: none;
border-color: var(--color-primary);
box-shadow: 0 0 0 3px rgba(79, 70, 229, 0.1);
}
.conversation-id {
padding: 0.75rem;
background: var(--color-bg-alt);
border-radius: var(--border-radius);
font-family: var(--font-mono);
font-size: 0.75rem;
word-break: break-all;
color: var(--color-text-light);
}
.fields-list {
list-style: none;
display: flex;
flex-direction: column;
gap: 0.5rem;
}
.fields-list li {
padding: 0.5rem;
background: var(--color-bg-alt);
border-radius: var(--border-radius);
font-size: 0.875rem;
color: var(--color-text);
}
.fields-list li.placeholder {
color: var(--color-text-light);
font-style: italic;
}
.fields-list li.required::before {
content: '● ';
color: var(--color-danger);
}
/* ============================================================================
Buttons
============================================================================ */
.btn {
padding: 0.625rem 1rem;
border: none;
border-radius: var(--border-radius);
font-family: var(--font-family);
font-size: 0.875rem;
font-weight: 500;
cursor: pointer;
transition: var(--transition);
width: 100%;
text-align: center;
}
.btn:disabled {
opacity: 0.5;
cursor: not-allowed;
}
.btn-primary {
background: var(--color-primary);
color: white;
}
.btn-primary:hover:not(:disabled) {
background: var(--color-primary-dark);
box-shadow: var(--shadow-md);
}
.btn-secondary {
background: var(--color-bg-alt);
color: var(--color-text);
border: 1px solid var(--color-border);
}
.btn-secondary:hover:not(:disabled) {
background: var(--color-bg-hover);
border-color: var(--color-primary);
}
.btn-danger {
background: var(--color-danger);
color: white;
}
.btn-danger:hover:not(:disabled) {
background: #dc2626;
box-shadow: var(--shadow-md);
}
/* ============================================================================
Chat Section
============================================================================ */
.chat-section {
display: flex;
flex-direction: column;
gap: 1rem;
flex: 1;
min-height: 0;
}
.chat-section h2 {
font-size: 1.125rem;
margin-bottom: 0.5rem;
}
.chat-container {
flex: 1;
background: var(--color-bg);
border: 1px solid var(--color-border);
border-radius: var(--border-radius);
padding: 1rem;
overflow-y: auto;
display: flex;
flex-direction: column;
gap: 1rem;
}
.welcome-message {
display: flex;
flex-direction: column;
gap: 0.5rem;
text-align: center;
color: var(--color-text-light);
padding: 2rem;
justify-content: center;
}
.welcome-message p:first-child {
font-size: 1.5rem;
}
.message {
display: flex;
gap: 0.75rem;
margin-bottom: 0.5rem;
}
.message.user {
justify-content: flex-end;
}
.message.assistant {
justify-content: flex-start;
}
.message-bubble {
max-width: 70%;
padding: 0.75rem 1rem;
border-radius: var(--border-radius);
word-wrap: break-word;
line-height: 1.5;
}
.message.user .message-bubble {
background: var(--color-primary);
color: white;
}
.message.assistant .message-bubble {
background: var(--color-bg-alt);
color: var(--color-text);
border: 1px solid var(--color-border);
}
.message-time {
font-size: 0.75rem;
color: var(--color-text-lighter);
padding: 0 0.5rem;
align-self: flex-end;
}
.streaming {
display: flex;
gap: 0.25rem;
}
.streaming span {
width: 0.5rem;
height: 0.5rem;
background: var(--color-text-light);
border-radius: 50%;
animation: bounce 1.4s infinite;
}
.streaming span:nth-child(2) {
animation-delay: 0.2s;
}
.streaming span:nth-child(3) {
animation-delay: 0.4s;
}
@keyframes bounce {
0%, 80%, 100% {
opacity: 0.3;
transform: translateY(0);
}
40% {
opacity: 1;
transform: translateY(-0.5rem);
}
}
.input-area {
display: flex;
gap: 0.75rem;
}
.message-input {
flex: 1;
padding: 0.75rem;
border: 1px solid var(--color-border);
border-radius: var(--border-radius);
font-family: var(--font-family);
font-size: 0.875rem;
resize: none;
max-height: 120px;
color: var(--color-text);
}
.message-input:focus {
outline: none;
border-color: var(--color-primary);
box-shadow: 0 0 0 3px rgba(79, 70, 229, 0.1);
}
.message-input:disabled {
background: var(--color-bg-alt);
color: var(--color-text-light);
}
/* ============================================================================
Suggestions Section
============================================================================ */
.suggestions-section {
background: var(--color-bg);
border: 1px solid var(--color-border);
border-radius: var(--border-radius);
padding: 1rem;
margin-top: 1rem;
}
.suggestions-section h3 {
margin-bottom: 0.75rem;
}
.suggestions-container {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
gap: 1rem;
}
.suggestion-card {
background: var(--color-bg-alt);
border: 1px solid var(--color-border);
border-radius: var(--border-radius);
padding: 1rem;
cursor: pointer;
transition: var(--transition);
}
.suggestion-card:hover {
border-color: var(--color-primary);
box-shadow: var(--shadow-md);
transform: translateY(-2px);
}
.suggestion-field {
font-weight: 600;
color: var(--color-primary);
font-size: 0.875rem;
text-transform: uppercase;
margin-bottom: 0.5rem;
}
.suggestion-value {
font-family: var(--font-mono);
font-size: 0.875rem;
padding: 0.5rem;
background: var(--color-bg);
border-radius: 4px;
margin-bottom: 0.5rem;
word-break: break-all;
}
.suggestion-reasoning {
font-size: 0.75rem;
color: var(--color-text-light);
margin-bottom: 0.5rem;
}
.suggestion-confidence {
display: flex;
align-items: center;
font-size: 0.75rem;
color: var(--color-text-light);
gap: 0.5rem;
}
.confidence-bar {
flex: 1;
height: 4px;
background: var(--color-border);
border-radius: 2px;
overflow: hidden;
}
.confidence-fill {
height: 100%;
background: var(--color-success);
border-radius: 2px;
}
/* ============================================================================
Configuration Section
============================================================================ */
.config-section {
background: var(--color-bg);
border: 1px solid var(--color-border);
border-radius: var(--border-radius);
padding: 1rem;
margin-top: 1rem;
}
.config-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 1rem;
}
.config-header h3 {
margin: 0;
}
.format-select {
padding: 0.5rem;
width: 150px;
}
.config-preview {
background: var(--color-text);
color: #e0e0e0;
padding: 1rem;
border-radius: var(--border-radius);
overflow-x: auto;
font-family: var(--font-mono);
font-size: 0.875rem;
line-height: 1.5;
max-height: 300px;
margin-bottom: 1rem;
}
.config-preview code {
font-family: var(--font-mono);
}
.config-actions {
display: flex;
gap: 0.75rem;
}
.config-actions .btn {
flex: 1;
}
/* ============================================================================
Status & Connection
============================================================================ */
.status-text {
flex: 1;
}
.connection-status {
display: flex;
align-items: center;
gap: 0.5rem;
padding: 0.5rem 1rem;
border-radius: 999px;
background: var(--color-bg-alt);
font-size: 0.875rem;
}
.connection-status.connected {
color: var(--color-success);
}
.connection-status.disconnected {
color: var(--color-danger);
}
/* ============================================================================
Loading Overlay
============================================================================ */
.loading-overlay {
position: fixed;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: rgba(0, 0, 0, 0.5);
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
z-index: 1000;
gap: 1rem;
}
.spinner {
width: 40px;
height: 40px;
border: 4px solid rgba(255, 255, 255, 0.3);
border-top: 4px solid white;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
to {
transform: rotate(360deg);
}
}
.loading-overlay p {
color: white;
font-weight: 500;
}
/* ============================================================================
Responsive Design
============================================================================ */
@media (max-width: 1024px) {
.main-layout {
flex-direction: column;
}
.sidebar {
width: 100%;
max-height: 300px;
flex-direction: row;
padding: 1rem;
overflow-x: auto;
border-right: none;
border-bottom: 1px solid var(--color-border);
}
.sidebar-section {
min-width: 250px;
}
.main-content {
padding: 1rem;
}
.suggestions-container {
grid-template-columns: repeat(auto-fill, minmax(150px, 1fr));
}
.message-bubble {
max-width: 85%;
}
}
@media (max-width: 640px) {
.header-content h1 {
font-size: 1.5rem;
}
.sidebar {
flex-direction: column;
overflow-x: visible;
max-height: none;
border-bottom: 1px solid var(--color-border);
}
.sidebar-section {
min-width: unset;
}
.suggestions-container {
grid-template-columns: 1fr;
}
.message-bubble {
max-width: 95%;
}
.btn {
font-size: 0.875rem;
padding: 0.5rem;
}
}
/* ============================================================================
Scrollbar Styling
============================================================================ */
::-webkit-scrollbar {
width: 8px;
height: 8px;
}
::-webkit-scrollbar-track {
background: var(--color-bg-alt);
}
::-webkit-scrollbar-thumb {
background: var(--color-border);
border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
background: var(--color-text-light);
}