chore: updates and fixes
parent 9ca1bfb8cf
commit 64ea463b69
@@ -23,13 +23,13 @@ thiserror.workspace = true
 async-trait.workspace = true
 tera = { workspace = true, optional = true }
 tempfile.workspace = true
+dirs.workspace = true # For config path resolution

 # i18n (optional)
 fluent = { workspace = true, optional = true }
 fluent-bundle = { workspace = true, optional = true }
 unic-langid = { workspace = true, optional = true }
 sys-locale = { workspace = true, optional = true }
-dirs = { workspace = true, optional = true }

 # Nushell integration (optional)
 nu-protocol = { workspace = true, optional = true }
@@ -57,22 +57,37 @@ futures = { workspace = true, optional = true }
 # Encryption - optional (prov-ecosystem integration)
 encrypt = { path = "../../../prov-ecosystem/crates/encrypt", optional = true }

 # AI Backend - optional
 instant-distance = { workspace = true, optional = true }
 tantivy = { workspace = true, optional = true }
 bincode = { workspace = true, optional = true }
 serde_bytes = { workspace = true, optional = true }
 rand = { workspace = true, optional = true }
 petgraph = { workspace = true, optional = true }

 [dev-dependencies]
 serde_json.workspace = true
 tokio = { workspace = true, features = ["full"] }
 age = "0.11"
 proptest.workspace = true
 criterion.workspace = true

 [features]
 default = ["cli", "i18n", "templates"]
 cli = ["inquire", "dialoguer", "rpassword"]
 tui = ["ratatui", "crossterm", "atty"]
 web = ["axum", "tokio", "tower", "tower-http", "tracing", "tracing-subscriber", "futures"]
-i18n = ["fluent", "fluent-bundle", "unic-langid", "sys-locale", "dirs"]
+i18n = ["fluent", "fluent-bundle", "unic-langid", "sys-locale"]
 templates = ["tera"]
 nushell = ["nu-protocol", "nu-plugin"]
 encryption = ["encrypt"]
 ai_backend = ["instant-distance", "tantivy", "bincode", "serde_bytes", "rand", "petgraph"]
 all-backends = ["cli", "tui", "web"]
-full = ["i18n", "templates", "nushell", "encryption", "all-backends"]
+full = ["i18n", "templates", "nushell", "encryption", "ai_backend", "all-backends"]

 [[bench]]
 name = "parsing_benchmarks"
 harness = false

 [lints]
 workspace = true
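
As a quick orientation for the feature table above, a couple of build invocations one might use (the `-p typedialog-core` package name is inferred from the crate paths in this commit, not stated in the diff):

    cargo build -p typedialog-core --features ai_backend
    cargo build -p typedialog-core --features full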
crates/typedialog-core/benches/parsing_benchmarks.rs (new file, 251 lines)
@@ -0,0 +1,251 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use typedialog_core::form_parser::parse_toml;
use typedialog_core::nickel::MetadataParser;

fn benchmark_parse_toml_simple(c: &mut Criterion) {
    let simple_form = r#"
[[fields]]
name = "username"
type = "text"
prompt = "Enter username"

[[fields]]
name = "email"
type = "text"
prompt = "Enter email"
"#;

    c.bench_function("parse_toml_simple", |b| {
        b.iter(|| {
            parse_toml(black_box(simple_form)).unwrap();
        })
    });
}

fn benchmark_parse_toml_complex(c: &mut Criterion) {
    let complex_form = r#"
[[fields]]
name = "username"
type = "text"
prompt = "Enter username"
min = 3
max = 20
required = true

[[fields]]
name = "email"
type = "text"
prompt = "Enter email"
required = true

[[fields]]
name = "password"
type = "password"
prompt = "Enter password"
min = 8
required = true

[[fields]]
name = "role"
type = "select"
prompt = "Select role"
choices = ["Admin", "User", "Guest"]
default = "User"

[[fields]]
name = "permissions"
type = "multiselect"
prompt = "Select permissions"
choices = ["Read", "Write", "Execute", "Delete"]

[[fields]]
name = "bio"
type = "editor"
prompt = "Enter bio"
default = ""

[[fields]]
name = "birthdate"
type = "date"
prompt = "Enter birthdate"

[[fields]]
name = "active"
type = "confirm"
prompt = "Is active?"
default = true
"#;

    c.bench_function("parse_toml_complex", |b| {
        b.iter(|| {
            parse_toml(black_box(complex_form)).unwrap();
        })
    });
}

fn benchmark_parse_toml_nested(c: &mut Criterion) {
    let nested_form = r#"
[[fields]]
name = "server.host"
type = "text"
prompt = "Server hostname"
default = "localhost"

[[fields]]
name = "server.port"
type = "text"
prompt = "Server port"
default = "8080"

[[fields]]
name = "database.host"
type = "text"
prompt = "Database host"

[[fields]]
name = "database.port"
type = "text"
prompt = "Database port"
default = "5432"

[[fields]]
name = "database.name"
type = "text"
prompt = "Database name"
required = true

[[fields]]
name = "tls.enabled"
type = "confirm"
prompt = "Enable TLS?"
default = false

[[fields]]
name = "tls.cert_path"
type = "text"
prompt = "Certificate path"
when = "tls.enabled == true"
"#;

    c.bench_function("parse_toml_nested", |b| {
        b.iter(|| {
            parse_toml(black_box(nested_form)).unwrap();
        })
    });
}

fn benchmark_parse_nickel_simple(c: &mut Criterion) {
    let simple_json = serde_json::json!({
        "username": "alice",
        "email": "alice@example.com"
    });

    c.bench_function("parse_nickel_simple", |b| {
        b.iter(|| {
            MetadataParser::parse(black_box(simple_json.clone())).unwrap();
        })
    });
}

fn benchmark_parse_nickel_complex(c: &mut Criterion) {
    let complex_json = serde_json::json!({
        "user": {
            "name": "Alice",
            "email": "alice@example.com",
            "age": 30
        },
        "settings": {
            "theme": "dark",
            "notifications": true,
            "language": "en"
        },
        "permissions": {
            "read": true,
            "write": true,
            "delete": false
        }
    });

    c.bench_function("parse_nickel_complex", |b| {
        b.iter(|| {
            MetadataParser::parse(black_box(complex_json.clone())).unwrap();
        })
    });
}

fn benchmark_parse_nickel_deeply_nested(c: &mut Criterion) {
    let nested_json = serde_json::json!({
        "server": {
            "http": {
                "host": "localhost",
                "port": 8080,
                "tls": {
                    "enabled": false,
                    "cert_path": "/etc/certs/server.pem",
                    "key_path": "/etc/certs/server.key"
                }
            },
            "database": {
                "primary": {
                    "host": "db1.example.com",
                    "port": 5432,
                    "name": "mydb"
                },
                "replica": {
                    "host": "db2.example.com",
                    "port": 5432,
                    "name": "mydb"
                }
            }
        }
    });

    c.bench_function("parse_nickel_deeply_nested", |b| {
        b.iter(|| {
            MetadataParser::parse(black_box(nested_json.clone())).unwrap();
        })
    });
}

fn benchmark_parse_nickel_with_metadata(c: &mut Criterion) {
    let metadata_json = serde_json::json!({
        "username": {
            "type": "String",
            "doc": "The username for authentication",
            "default": "admin",
            "optional": false
        },
        "email": {
            "type": "String",
            "doc": "Email address",
            "contract": "String | std.string.Email",
            "optional": false
        },
        "age": {
            "type": "Number",
            "doc": "User age",
            "contract": "Number | std.number.between 0 120",
            "optional": true,
            "default": 25
        }
    });

    c.bench_function("parse_nickel_with_metadata", |b| {
        b.iter(|| {
            MetadataParser::parse(black_box(metadata_json.clone())).unwrap();
        })
    });
}

criterion_group!(
    parsing_benches,
    benchmark_parse_toml_simple,
    benchmark_parse_toml_complex,
    benchmark_parse_toml_nested,
    benchmark_parse_nickel_simple,
    benchmark_parse_nickel_complex,
    benchmark_parse_nickel_deeply_nested,
    benchmark_parse_nickel_with_metadata
);

criterion_main!(parsing_benches);
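
Since the manifest registers this bench target with `harness = false`, Criterion supplies the benchmark main itself; the typical invocation (package name inferred from the crate path) would be:

    cargo bench -p typedialog-core --bench parsing_benchmarks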
crates/typedialog-core/src/ai/embeddings.rs (new file, 264 lines)
@@ -0,0 +1,264 @@
//! Embeddings service for text encoding
//!
//! Provides functionality to convert text into dense vector embeddings
//! for semantic similarity and retrieval tasks.
//!
//! Uses deterministic hashing for embedding generation. For production,
//! integrate with actual ML models like fastembed or OpenAI's API.

use crate::error::{ErrorWrapper, Result};
use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Embedding model types supported by the system
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EmbeddingModel {
    /// BAAI/bge-small-en-v1.5 (384 dimensions, faster)
    BgeSmallEn,
    /// BAAI/bge-base-en-v1.5 (768 dimensions, better quality)
    BgeBaseEn,
    /// Multilingual BAAI/bge-m3 (1024 dimensions)
    BgeM3,
}

impl EmbeddingModel {
    /// Get the HuggingFace model identifier
    pub fn model_name(&self) -> &'static str {
        match self {
            EmbeddingModel::BgeSmallEn => "BAAI/bge-small-en-v1.5",
            EmbeddingModel::BgeBaseEn => "BAAI/bge-base-en-v1.5",
            EmbeddingModel::BgeM3 => "BAAI/bge-m3",
        }
    }

    /// Get embedding dimension size
    pub fn dimension(&self) -> usize {
        match self {
            EmbeddingModel::BgeSmallEn => 384,
            EmbeddingModel::BgeBaseEn => 768,
            EmbeddingModel::BgeM3 => 1024,
        }
    }

    /// Convert to string representation
    pub fn as_str(&self) -> &'static str {
        match self {
            EmbeddingModel::BgeSmallEn => "BgeSmallEn",
            EmbeddingModel::BgeBaseEn => "BgeBaseEn",
            EmbeddingModel::BgeM3 => "BgeM3",
        }
    }

    /// Parse from string representation
    pub fn parse(s: &str) -> Self {
        match s {
            "BgeBaseEn" => EmbeddingModel::BgeBaseEn,
            "BgeM3" => EmbeddingModel::BgeM3,
            _ => EmbeddingModel::BgeSmallEn, // Default fallback
        }
    }
}

/// Embedding vector with metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    /// The embedding vector (dense float array)
    pub vector: Vec<f32>,
    /// Original text that was embedded
    pub text: String,
    /// Model used to generate this embedding
    pub model: EmbeddingModel,
}

impl Embedding {
    /// Calculate cosine similarity between this embedding and another
    pub fn cosine_similarity(&self, other: &Embedding) -> f32 {
        if self.vector.is_empty() || other.vector.is_empty() {
            return 0.0;
        }

        let dot: f32 = self
            .vector
            .iter()
            .zip(&other.vector)
            .map(|(a, b)| a * b)
            .sum();

        let norm_a: f32 = self.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = other.vector.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot / (norm_a * norm_b)
        }
    }
}

/// Embeddings service for encoding text using deterministic hashing
///
/// Uses a hash-based approach for generating embeddings. For production,
/// replace with actual ML model integration (fastembed, OpenAI, etc).
#[cfg(feature = "ai_backend")]
pub struct EmbeddingsService {
    model_type: EmbeddingModel,
}

#[cfg(feature = "ai_backend")]
impl EmbeddingsService {
    /// Create a new embeddings service with the specified model
    pub fn new(model_type: EmbeddingModel) -> Result<Self> {
        Ok(EmbeddingsService { model_type })
    }

    /// Embed a single text using deterministic hashing
    pub fn embed(&self, text: &str) -> Result<Embedding> {
        if text.is_empty() {
            return Err(ErrorWrapper::validation_failed("Cannot embed empty text"));
        }

        let vector = Self::hash_to_embedding(text, self.model_type.dimension());

        Ok(Embedding {
            vector,
            text: text.to_string(),
            model: self.model_type,
        })
    }

    /// Embed multiple texts
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>> {
        if texts.is_empty() {
            return Err(ErrorWrapper::validation_failed("Cannot embed empty batch"));
        }

        Ok(texts
            .iter()
            .map(|text| {
                let vector = Self::hash_to_embedding(text, self.model_type.dimension());
                Embedding {
                    vector,
                    text: text.to_string(),
                    model: self.model_type,
                }
            })
            .collect())
    }

    /// Get the model type
    pub fn model_type(&self) -> EmbeddingModel {
        self.model_type
    }

    /// Get embedding dimension
    pub fn dimension(&self) -> usize {
        self.model_type.dimension()
    }

    /// Generate embedding vector from text hash
    fn hash_to_embedding(text: &str, dim: usize) -> Vec<f32> {
        let mut vector = vec![0.0; dim];

        // Use each word in the text to generate different hash values
        for (i, word) in text.split_whitespace().enumerate() {
            let mut hasher = DefaultHasher::new();
            word.hash(&mut hasher);
            let hash = hasher.finish();

            // Distribute hash bits across the vector
            for (j, vec_elem) in vector.iter_mut().enumerate() {
                let bit_index = (i * 64 + j) % 64;
                let bit = (hash >> bit_index) & 1;
                *vec_elem += if bit == 1 { 1.0 } else { -1.0 };
            }
        }

        // Normalize vector
        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            vector.iter_mut().for_each(|x| *x /= norm);
        }

        vector
    }
}

#[cfg(all(test, feature = "ai_backend"))]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_model_names() {
        assert_eq!(
            EmbeddingModel::BgeSmallEn.model_name(),
            "BAAI/bge-small-en-v1.5"
        );
        assert_eq!(
            EmbeddingModel::BgeBaseEn.model_name(),
            "BAAI/bge-base-en-v1.5"
        );
        assert_eq!(EmbeddingModel::BgeM3.model_name(), "BAAI/bge-m3");
    }

    #[test]
    fn test_embedding_dimensions() {
        assert_eq!(EmbeddingModel::BgeSmallEn.dimension(), 384);
        assert_eq!(EmbeddingModel::BgeBaseEn.dimension(), 768);
        assert_eq!(EmbeddingModel::BgeM3.dimension(), 1024);
    }

    #[test]
    fn test_cosine_similarity() {
        let emb1 = Embedding {
            vector: vec![1.0, 0.0, 0.0],
            text: "test1".to_string(),
            model: EmbeddingModel::BgeSmallEn,
        };

        let emb2 = Embedding {
            vector: vec![1.0, 0.0, 0.0],
            text: "test2".to_string(),
            model: EmbeddingModel::BgeSmallEn,
        };

        let similarity = emb1.cosine_similarity(&emb2);
        assert!((similarity - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let emb1 = Embedding {
            vector: vec![1.0, 0.0, 0.0],
            text: "test1".to_string(),
            model: EmbeddingModel::BgeSmallEn,
        };

        let emb2 = Embedding {
            vector: vec![0.0, 1.0, 0.0],
            text: "test2".to_string(),
            model: EmbeddingModel::BgeSmallEn,
        };

        let similarity = emb1.cosine_similarity(&emb2);
        assert!(similarity.abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let emb1 = Embedding {
            vector: vec![1.0, 0.0, 0.0],
            text: "test1".to_string(),
            model: EmbeddingModel::BgeSmallEn,
        };

        let emb2 = Embedding {
            vector: vec![-1.0, 0.0, 0.0],
            text: "test2".to_string(),
            model: EmbeddingModel::BgeSmallEn,
        };

        let similarity = emb1.cosine_similarity(&emb2);
        assert!((similarity + 1.0).abs() < 0.001);
    }
}
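
A minimal usage sketch of the service above, assuming the `ai_backend` feature is enabled and that the module is exposed as `typedialog_core::ai::embeddings` (inferred from the file path, not confirmed by this diff):

    use typedialog_core::ai::embeddings::{EmbeddingModel, EmbeddingsService};

    fn embeddings_demo() -> typedialog_core::error::Result<()> {
        // 384-dimensional, hash-based vectors
        let svc = EmbeddingsService::new(EmbeddingModel::BgeSmallEn)?;
        let a = svc.embed("configure the database host")?;
        let b = svc.embed("set up the database host")?;
        // Hash-based embeddings are deterministic: similarity reflects
        // shared words rather than true semantic closeness.
        println!("cosine similarity: {}", a.cosine_similarity(&b));
        Ok(())
    }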
crates/typedialog-core/src/ai/indexer.rs (new file, 264 lines)
@@ -0,0 +1,264 @@
//! Full-text indexing for keyword search
//!
//! Provides full-text search capabilities that complement semantic
//! search with keyword matching. The current implementation is a simple
//! in-memory term scan rather than a tantivy-backed inverted index.

use crate::error::{ErrorWrapper, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Full-text search result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Document identifier
    pub doc_id: String,
    /// Document content
    pub content: String,
    /// Relevance score
    pub score: f32,
    /// Metadata
    pub metadata: HashMap<String, String>,
}

/// Document for indexing
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Unique document identifier
    pub id: String,
    /// Main content to index
    pub content: String,
    /// Additional fields
    pub fields: HashMap<String, String>,
}

impl Document {
    /// Create a new document
    pub fn new(id: String, content: String) -> Self {
        Document {
            id,
            content,
            fields: HashMap::new(),
        }
    }

    /// Add a field to the document
    pub fn with_field(mut self, key: String, value: String) -> Self {
        self.fields.insert(key, value);
        self
    }
}

/// Full-text indexer
#[cfg(feature = "ai_backend")]
#[derive(Serialize, Deserialize, Default)]
pub struct FullTextIndexer {
    documents: HashMap<String, Document>,
}

#[cfg(feature = "ai_backend")]
impl FullTextIndexer {
    /// Create a new indexer
    pub fn new() -> Self {
        Self::default()
    }

    /// Index a document
    pub fn index(&mut self, doc: Document) -> Result<()> {
        if doc.id.is_empty() {
            return Err(ErrorWrapper::validation_failed(
                "Document ID cannot be empty",
            ));
        }
        if doc.content.is_empty() {
            return Err(ErrorWrapper::validation_failed(
                "Document content cannot be empty",
            ));
        }

        self.documents.insert(doc.id.clone(), doc);
        Ok(())
    }

    /// Search documents by keyword
    pub fn search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
        if query.is_empty() {
            return Err(ErrorWrapper::validation_failed("Query cannot be empty"));
        }

        let query_lower = query.to_lowercase();
        let query_terms: Vec<_> = query_lower
            .split_whitespace()
            .filter(|t| !t.is_empty())
            .collect();

        if query_terms.is_empty() {
            return Ok(Vec::new());
        }

        // Simple word matching and scoring
        let mut results: Vec<SearchResult> = self
            .documents
            .values()
            .filter_map(|doc| {
                let content_lower = doc.content.to_lowercase();
                let mut score = 0.0;
                let mut matches = 0;

                // Count term matches
                for term in &query_terms {
                    let count = content_lower.matches(term).count();
                    if count > 0 {
                        matches += 1;
                        score += count as f32 * 10.0; // Boost matching terms
                    }
                }

                // Only return results with matches
                if matches > 0 {
                    Some(SearchResult {
                        doc_id: doc.id.clone(),
                        content: doc.content.clone(),
                        score,
                        metadata: doc.fields.clone(),
                    })
                } else {
                    None
                }
            })
            .collect();

        // Sort by relevance (descending)
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Return limited results
        Ok(results.into_iter().take(limit).collect())
    }

    /// Remove a document
    pub fn remove(&mut self, id: &str) -> Option<Document> {
        self.documents.remove(id)
    }

    /// Get document count
    pub fn doc_count(&self) -> usize {
        self.documents.len()
    }

    /// Clear all documents
    pub fn clear(&mut self) {
        self.documents.clear();
    }

    /// Get a document by ID
    pub fn get(&self, id: &str) -> Option<Document> {
        self.documents.get(id).cloned()
    }

    /// Serialize indexer for persistence
    pub fn serialize(&self) -> Result<Vec<u8>> {
        bincode::serialize(self).map_err(|e| {
            ErrorWrapper::validation_failed(format!("Failed to serialize indexer: {}", e))
        })
    }

    /// Deserialize indexer from persistence
    pub fn deserialize(data: &[u8]) -> Result<Self> {
        bincode::deserialize(data).map_err(|e| {
            ErrorWrapper::validation_failed(format!("Failed to deserialize indexer: {}", e))
        })
    }
}

#[cfg(not(feature = "ai_backend"))]
pub struct FullTextIndexer;

#[cfg(all(test, feature = "ai_backend"))]
mod tests {
    use super::*;

    #[test]
    fn test_indexer_index() {
        let mut indexer = FullTextIndexer::new();
        let doc = Document::new("doc1".to_string(), "hello world".to_string());
        assert!(indexer.index(doc).is_ok());
        assert_eq!(indexer.doc_count(), 1);
    }

    #[test]
    fn test_indexer_empty_id() {
        let mut indexer = FullTextIndexer::new();
        let doc = Document::new(String::new(), "content".to_string());
        assert!(indexer.index(doc).is_err());
    }

    #[test]
    fn test_indexer_empty_content() {
        let mut indexer = FullTextIndexer::new();
        let doc = Document::new("doc1".to_string(), String::new());
        assert!(indexer.index(doc).is_err());
    }

    #[test]
    fn test_indexer_search() {
        let mut indexer = FullTextIndexer::new();
        indexer
            .index(Document::new("doc1".to_string(), "hello world".to_string()))
            .unwrap();
        indexer
            .index(Document::new(
                "doc2".to_string(),
                "goodbye world".to_string(),
            ))
            .unwrap();

        let results = indexer.search("hello", 10).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, "doc1");
    }

    #[test]
    fn test_indexer_search_multiple_terms() {
        let mut indexer = FullTextIndexer::new();
        indexer
            .index(Document::new(
                "doc1".to_string(),
                "hello world example".to_string(),
            ))
            .unwrap();
        indexer
            .index(Document::new(
                "doc2".to_string(),
                "goodbye world".to_string(),
            ))
            .unwrap();

        let results = indexer.search("hello world", 10).unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn test_indexer_remove() {
        let mut indexer = FullTextIndexer::new();
        let doc = Document::new("doc1".to_string(), "hello world".to_string());
        indexer.index(doc).unwrap();
        assert_eq!(indexer.doc_count(), 1);

        let removed = indexer.remove("doc1");
        assert!(removed.is_some());
        assert_eq!(indexer.doc_count(), 0);
    }

    #[test]
    fn test_indexer_get() {
        let mut indexer = FullTextIndexer::new();
        let doc = Document::new("doc1".to_string(), "hello world".to_string());
        indexer.index(doc).unwrap();

        let retrieved = indexer.get("doc1");
        assert!(retrieved.is_some());
    }
}
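
A sketch of the indexer's round trip, under the same module-path assumption (`typedialog_core::ai::indexer`):

    use typedialog_core::ai::indexer::{Document, FullTextIndexer};

    fn indexer_demo() -> typedialog_core::error::Result<()> {
        let mut idx = FullTextIndexer::new();
        idx.index(
            Document::new("doc1".into(), "TLS certificate paths".into())
                .with_field("section".into(), "security".into()),
        )?;
        // Scores are term-occurrence counts scaled by 10.0, sorted descending.
        let hits = idx.search("certificate", 5)?;
        assert_eq!(hits[0].doc_id, "doc1");
        // Persistence is a bincode round trip.
        let bytes = idx.serialize()?;
        assert_eq!(FullTextIndexer::deserialize(&bytes)?.doc_count(), 1);
        Ok(())
    }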
crates/typedialog-core/src/ai/kg/entities.rs (new file, 251 lines)
@@ -0,0 +1,251 @@
//! Knowledge Graph Entity Types and Structures
//!
//! Defines entity types, relationships, and properties for the knowledge graph.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Entity types in the knowledge graph
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EntityType {
    /// Person entity
    Person,
    /// Organization entity
    Organization,
    /// Location entity
    Location,
    /// Concept/Topic entity
    Concept,
    /// Event entity
    Event,
    /// Document entity
    Document,
    /// Generic/Unknown entity
    Generic,
}

impl EntityType {
    /// Get string representation
    pub fn as_str(&self) -> &'static str {
        match self {
            EntityType::Person => "PERSON",
            EntityType::Organization => "ORGANIZATION",
            EntityType::Location => "LOCATION",
            EntityType::Concept => "CONCEPT",
            EntityType::Event => "EVENT",
            EntityType::Document => "DOCUMENT",
            EntityType::Generic => "GENERIC",
        }
    }

    /// Parse from string
    pub fn parse(s: &str) -> Self {
        match s.to_uppercase().as_str() {
            "PERSON" => EntityType::Person,
            "ORGANIZATION" => EntityType::Organization,
            "LOCATION" => EntityType::Location,
            "CONCEPT" => EntityType::Concept,
            "EVENT" => EntityType::Event,
            "DOCUMENT" => EntityType::Document,
            _ => EntityType::Generic,
        }
    }
}

/// Relationship types in the knowledge graph
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RelationType {
    /// "is related to" relationship
    RelatedTo,
    /// "is part of" relationship
    PartOf,
    /// "contains" relationship
    Contains,
    /// "created by" relationship
    CreatedBy,
    /// "located in" relationship
    LocatedIn,
    /// "works for" relationship
    WorksFor,
    /// "participates in" relationship
    ParticipatesIn,
    /// "mentions" relationship
    Mentions,
    /// "references" relationship
    References,
    /// Custom relationship
    Custom(String),
}

impl RelationType {
    /// Get string representation
    pub fn as_str(&self) -> &'static str {
        match self {
            RelationType::RelatedTo => "RELATED_TO",
            RelationType::PartOf => "PART_OF",
            RelationType::Contains => "CONTAINS",
            RelationType::CreatedBy => "CREATED_BY",
            RelationType::LocatedIn => "LOCATED_IN",
            RelationType::WorksFor => "WORKS_FOR",
            RelationType::ParticipatesIn => "PARTICIPATES_IN",
            RelationType::Mentions => "MENTIONS",
            RelationType::References => "REFERENCES",
            RelationType::Custom(_) => "CUSTOM",
        }
    }

    /// Get custom string if applicable
    pub fn custom_name(&self) -> Option<&str> {
        if let RelationType::Custom(name) = self {
            Some(name)
        } else {
            None
        }
    }
}

/// Entity in the knowledge graph
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
    /// Unique identifier
    pub id: String,
    /// Entity name/label
    pub name: String,
    /// Entity type
    pub entity_type: EntityType,
    /// Description
    pub description: Option<String>,
    /// Custom properties
    pub properties: HashMap<String, String>,
    /// Source documents where this entity was found
    pub sources: Vec<String>,
}

impl Entity {
    /// Create a new entity
    pub fn new(id: String, name: String, entity_type: EntityType) -> Self {
        Entity {
            id,
            name,
            entity_type,
            description: None,
            properties: HashMap::new(),
            sources: Vec::new(),
        }
    }

    /// Add description
    pub fn with_description(mut self, desc: String) -> Self {
        self.description = Some(desc);
        self
    }

    /// Add a property
    pub fn with_property(mut self, key: String, value: String) -> Self {
        self.properties.insert(key, value);
        self
    }

    /// Add a source document
    pub fn with_source(mut self, source: String) -> Self {
        self.sources.push(source);
        self
    }

    /// Add multiple sources
    pub fn with_sources(mut self, sources: Vec<String>) -> Self {
        self.sources.extend(sources);
        self
    }
}

/// Relationship/Edge between entities
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relationship {
    /// Source entity ID
    pub source: String,
    /// Target entity ID
    pub target: String,
    /// Relationship type
    pub rel_type: RelationType,
    /// Relationship strength/weight (0.0 - 1.0)
    pub weight: f32,
    /// Optional relationship properties
    pub properties: HashMap<String, String>,
}

impl Relationship {
    /// Create a new relationship
    pub fn new(source: String, target: String, rel_type: RelationType) -> Self {
        Relationship {
            source,
            target,
            rel_type,
            weight: 1.0,
            properties: HashMap::new(),
        }
    }

    /// Set relationship weight
    pub fn with_weight(mut self, weight: f32) -> Self {
        self.weight = weight.clamp(0.0, 1.0);
        self
    }

    /// Add a property
    pub fn with_property(mut self, key: String, value: String) -> Self {
        self.properties.insert(key, value);
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_entity_type_as_str() {
        assert_eq!(EntityType::Person.as_str(), "PERSON");
        assert_eq!(EntityType::Organization.as_str(), "ORGANIZATION");
        assert_eq!(EntityType::Location.as_str(), "LOCATION");
    }

    #[test]
    fn test_entity_type_parse() {
        assert_eq!(EntityType::parse("PERSON"), EntityType::Person);
        assert_eq!(EntityType::parse("person"), EntityType::Person);
        assert_eq!(EntityType::parse("INVALID"), EntityType::Generic);
    }

    #[test]
    fn test_entity_creation() {
        let entity = Entity::new("e1".into(), "Alice".into(), EntityType::Person)
            .with_description("A software engineer".into())
            .with_property("role".into(), "engineer".into());

        assert_eq!(entity.id, "e1");
        assert_eq!(entity.name, "Alice");
        assert_eq!(entity.entity_type, EntityType::Person);
        assert_eq!(entity.description, Some("A software engineer".into()));
        assert_eq!(entity.properties.get("role"), Some(&"engineer".into()));
    }

    #[test]
    fn test_relationship_creation() {
        let rel = Relationship::new("e1".into(), "e2".into(), RelationType::WorksFor)
            .with_weight(0.9)
            .with_property("since".into(), "2020".into());

        assert_eq!(rel.source, "e1");
        assert_eq!(rel.target, "e2");
        assert_eq!(rel.weight, 0.9);
        assert_eq!(rel.properties.get("since"), Some(&"2020".into()));
    }

    #[test]
    fn test_relationship_type_custom() {
        let rel_type = RelationType::Custom("knows_well".into());
        assert_eq!(rel_type.as_str(), "CUSTOM");
        assert_eq!(rel_type.custom_name(), Some("knows_well"));
    }
}
crates/typedialog-core/src/ai/kg/graph.rs (new file, 509 lines)
@@ -0,0 +1,509 @@
//! Knowledge Graph Implementation using petgraph
//!
//! Provides graph construction, traversal, and relationship management
//! using the petgraph library for efficient graph operations.

use crate::error::{ErrorWrapper, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[cfg(feature = "ai_backend")]
use petgraph::graph::{DiGraph, NodeIndex};
#[cfg(feature = "ai_backend")]
use petgraph::Direction;

#[cfg(test)]
use super::entities::RelationType;
use super::entities::{Entity, EntityType, Relationship};

/// Knowledge Graph using petgraph DiGraph
#[cfg(feature = "ai_backend")]
pub struct KnowledgeGraph {
    graph: DiGraph<Entity, Relationship>,
    entity_index: HashMap<String, NodeIndex>,
}

/// Serializable representation of a knowledge graph
#[cfg(feature = "ai_backend")]
#[derive(Serialize, Deserialize)]
struct KnowledgeGraphSnapshot {
    /// All entities in the graph
    entities: Vec<(String, Entity)>,
    /// All relationships in the graph
    relationships: Vec<(String, String, Relationship)>,
}

#[cfg(feature = "ai_backend")]
impl KnowledgeGraph {
    /// Create a new knowledge graph
    pub fn new() -> Self {
        KnowledgeGraph {
            graph: DiGraph::new(),
            entity_index: HashMap::new(),
        }
    }

    /// Add an entity to the graph
    pub fn add_entity(&mut self, entity: Entity) -> Result<()> {
        if entity.id.is_empty() {
            return Err(ErrorWrapper::validation_failed("Entity ID cannot be empty"));
        }

        if self.entity_index.contains_key(&entity.id) {
            return Err(ErrorWrapper::validation_failed(format!(
                "Entity with ID '{}' already exists",
                entity.id
            )));
        }

        let node_idx = self.graph.add_node(entity.clone());
        self.entity_index.insert(entity.id.clone(), node_idx);
        Ok(())
    }

    /// Add multiple entities
    pub fn add_entities(&mut self, entities: Vec<Entity>) -> Result<()> {
        for entity in entities {
            self.add_entity(entity)?;
        }
        Ok(())
    }

    /// Add a relationship between entities
    pub fn add_relationship(&mut self, relationship: Relationship) -> Result<()> {
        let source_idx = self.entity_index.get(&relationship.source).ok_or_else(|| {
            ErrorWrapper::validation_failed(format!(
                "Source entity '{}' not found",
                relationship.source
            ))
        })?;

        let target_idx = self.entity_index.get(&relationship.target).ok_or_else(|| {
            ErrorWrapper::validation_failed(format!(
                "Target entity '{}' not found",
                relationship.target
            ))
        })?;

        self.graph.add_edge(*source_idx, *target_idx, relationship);
        Ok(())
    }

    /// Get entity by ID
    pub fn get_entity(&self, id: &str) -> Option<&Entity> {
        self.entity_index
            .get(id)
            .and_then(|idx| self.graph.node_weight(*idx))
    }

    /// Get mutable entity by ID
    pub fn get_entity_mut(&mut self, id: &str) -> Option<&mut Entity> {
        if let Some(idx) = self.entity_index.get(id).copied() {
            self.graph.node_weight_mut(idx)
        } else {
            None
        }
    }

    /// Remove entity and its relationships
    pub fn remove_entity(&mut self, id: &str) -> Option<Entity> {
        if let Some(idx) = self.entity_index.remove(id) {
            self.graph.remove_node(idx)
        } else {
            None
        }
    }

    /// Find entities by type
    pub fn find_by_type(&self, entity_type: EntityType) -> Vec<&Entity> {
        self.graph
            .node_weights()
            .filter(|entity| entity.entity_type == entity_type)
            .collect()
    }

    /// Find entities by name (substring matching)
    pub fn find_by_name(&self, name: &str) -> Vec<&Entity> {
        let name_lower = name.to_lowercase();
        self.graph
            .node_weights()
            .filter(|entity| entity.name.to_lowercase().contains(&name_lower))
            .collect()
    }

    /// Get entities connected to a given entity
    pub fn get_neighbors(&self, id: &str) -> Result<Vec<&Entity>> {
        let idx = self
            .entity_index
            .get(id)
            .ok_or_else(|| ErrorWrapper::validation_failed(format!("Entity '{}' not found", id)))?;

        Ok(self
            .graph
            .neighbors(*idx)
            .filter_map(|neighbor_idx| self.graph.node_weight(neighbor_idx))
            .collect())
    }

    /// Get incoming neighbors (entities pointing to this entity)
    pub fn get_incoming_neighbors(&self, id: &str) -> Result<Vec<&Entity>> {
        let idx = self
            .entity_index
            .get(id)
            .ok_or_else(|| ErrorWrapper::validation_failed(format!("Entity '{}' not found", id)))?;

        Ok(self
            .graph
            .neighbors_directed(*idx, Direction::Incoming)
            .filter_map(|neighbor_idx| self.graph.node_weight(neighbor_idx))
            .collect())
    }

    /// Get relationships from an entity
    pub fn get_outgoing_relationships(&self, id: &str) -> Result<Vec<&Relationship>> {
        let idx = self
            .entity_index
            .get(id)
            .ok_or_else(|| ErrorWrapper::validation_failed(format!("Entity '{}' not found", id)))?;

        Ok(self
            .graph
            .edges_directed(*idx, Direction::Outgoing)
            .map(|edge| edge.weight())
            .collect())
    }

    /// Get all relationships to an entity
    pub fn get_incoming_relationships(&self, id: &str) -> Result<Vec<&Relationship>> {
        let idx = self
            .entity_index
            .get(id)
            .ok_or_else(|| ErrorWrapper::validation_failed(format!("Entity '{}' not found", id)))?;

        Ok(self
            .graph
            .edges_directed(*idx, Direction::Incoming)
            .map(|edge| edge.weight())
            .collect())
    }

    /// Find shortest path between two entities
    pub fn find_path(&self, from_id: &str, to_id: &str) -> Result<Option<Vec<String>>> {
        use petgraph::algo::astar;

        let from_idx = self.entity_index.get(from_id).ok_or_else(|| {
            ErrorWrapper::validation_failed(format!("Source '{}' not found", from_id))
        })?;

        let to_idx = self.entity_index.get(to_id).ok_or_else(|| {
            ErrorWrapper::validation_failed(format!("Target '{}' not found", to_id))
        })?;

        let result = astar(
            &self.graph,
            *from_idx,
            |finish| finish == *to_idx,
            |_| 1,
            |_| 0,
        );

        Ok(result.map(|(_, path)| {
            path.into_iter()
                .filter_map(|idx| self.graph.node_weight(idx).map(|entity| entity.id.clone()))
                .collect()
        }))
    }

    /// Get entity count
    pub fn entity_count(&self) -> usize {
        self.graph.node_count()
    }

    /// Get relationship count
    pub fn relationship_count(&self) -> usize {
        self.graph.edge_count()
    }

    /// Check if entity exists
    pub fn contains_entity(&self, id: &str) -> bool {
        self.entity_index.contains_key(id)
    }

    /// Clear the graph
    pub fn clear(&mut self) {
        self.graph.clear();
        self.entity_index.clear();
    }

    /// Get all entities
    pub fn all_entities(&self) -> Vec<&Entity> {
        self.graph.node_weights().collect()
    }

    /// Get all relationships
    pub fn all_relationships(&self) -> Vec<&Relationship> {
        self.graph.edge_weights().collect()
    }

    /// Serialize knowledge graph for persistence
    pub fn serialize(&self) -> Result<Vec<u8>> {
        // Convert to serializable snapshot format
        let entities: Vec<_> = self
            .entity_index
            .iter()
            .filter_map(|(id, &idx)| self.graph.node_weight(idx).map(|e| (id.clone(), e.clone())))
            .collect();

        let relationships: Vec<_> = self
            .graph
            .edge_indices()
            .filter_map(|edge_idx| {
                let (source_idx, target_idx) = self.graph.edge_endpoints(edge_idx)?;
                let source_id = self
                    .entity_index
                    .iter()
                    .find(|(_, &idx)| idx == source_idx)
                    .map(|(id, _)| id.clone())?;
                let target_id = self
                    .entity_index
                    .iter()
                    .find(|(_, &idx)| idx == target_idx)
                    .map(|(id, _)| id.clone())?;
                let rel = self.graph.edge_weight(edge_idx)?.clone();
                Some((source_id, target_id, rel))
            })
            .collect();

        let snapshot = KnowledgeGraphSnapshot {
            entities,
            relationships,
        };

        bincode::serialize(&snapshot).map_err(|e| {
            ErrorWrapper::validation_failed(format!("Failed to serialize knowledge graph: {}", e))
        })
    }

    /// Deserialize knowledge graph from persistence
    pub fn deserialize(data: &[u8]) -> Result<Self> {
        let snapshot: KnowledgeGraphSnapshot = bincode::deserialize(data).map_err(|e| {
            ErrorWrapper::validation_failed(format!("Failed to deserialize knowledge graph: {}", e))
        })?;

        let mut kg = KnowledgeGraph::new();

        // Add all entities
        for (_id, entity) in snapshot.entities {
            kg.add_entity(entity)?;
        }

        // Add all relationships
        for (source_id, target_id, rel) in snapshot.relationships {
            // Ensure relationship's source and target match the provided IDs
            let mut rel = rel;
            rel.source = source_id;
            rel.target = target_id;
            kg.add_relationship(rel)?;
        }

        Ok(kg)
    }

    /// Save knowledge graph to a file
    pub fn save_to_file(&self, path: &str) -> Result<()> {
        let graph_data = self.serialize()?;
        let snapshot = super::super::persistence::KnowledgeGraphSnapshot {
            version: super::super::persistence::Persistence::current_version(),
            graph_data,
        };
        super::super::persistence::Persistence::save_binary(&snapshot, path)
    }

    /// Load knowledge graph from a file
    pub fn load_from_file(path: &str) -> Result<Self> {
        let snapshot: super::super::persistence::KnowledgeGraphSnapshot =
            super::super::persistence::Persistence::load_binary(path)?;

        if !super::super::persistence::Persistence::is_version_compatible(snapshot.version) {
            return Err(ErrorWrapper::validation_failed(format!(
                "Incompatible snapshot version: {}",
                snapshot.version
            )));
        }

        Self::deserialize(&snapshot.graph_data)
    }
}

#[cfg(feature = "ai_backend")]
impl Default for KnowledgeGraph {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(all(test, feature = "ai_backend"))]
mod tests {
    use super::*;

    fn create_test_graph() -> KnowledgeGraph {
        let mut kg = KnowledgeGraph::new();
        let alice = Entity::new("e1".into(), "Alice".into(), EntityType::Person);
        let bob = Entity::new("e2".into(), "Bob".into(), EntityType::Person);
        let acme = Entity::new("e3".into(), "ACME Corp".into(), EntityType::Organization);

        kg.add_entity(alice).unwrap();
        kg.add_entity(bob).unwrap();
        kg.add_entity(acme).unwrap();

        kg.add_relationship(Relationship::new(
            "e1".into(),
            "e3".into(),
            RelationType::WorksFor,
        ))
        .unwrap();
        kg.add_relationship(Relationship::new(
            "e2".into(),
            "e3".into(),
            RelationType::WorksFor,
        ))
        .unwrap();
        kg.add_relationship(Relationship::new(
            "e1".into(),
            "e2".into(),
            RelationType::RelatedTo,
        ))
        .unwrap();

        kg
    }

    #[test]
    fn test_add_entity() {
        let mut kg = KnowledgeGraph::new();
        let entity = Entity::new("e1".into(), "Alice".into(), EntityType::Person);
        assert!(kg.add_entity(entity).is_ok());
        assert_eq!(kg.entity_count(), 1);
    }

    #[test]
    fn test_duplicate_entity() {
        let mut kg = KnowledgeGraph::new();
        let entity = Entity::new("e1".into(), "Alice".into(), EntityType::Person);
        kg.add_entity(entity.clone()).unwrap();
        assert!(kg.add_entity(entity).is_err());
    }

    #[test]
    fn test_add_relationship() {
        let kg = create_test_graph();
        assert_eq!(kg.relationship_count(), 3);
    }

    #[test]
    fn test_get_entity() {
        let kg = create_test_graph();
        let entity = kg.get_entity("e1").unwrap();
        assert_eq!(entity.name, "Alice");
    }

    #[test]
    fn test_find_by_type() {
        let kg = create_test_graph();
        let persons = kg.find_by_type(EntityType::Person);
        assert_eq!(persons.len(), 2);
    }

    #[test]
    fn test_find_by_name() {
        let kg = create_test_graph();
        let results = kg.find_by_name("Alice");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "e1");
    }

    #[test]
    fn test_get_neighbors() {
        let kg = create_test_graph();
        let neighbors = kg.get_neighbors("e1").unwrap();
        assert_eq!(neighbors.len(), 2); // Bob and ACME
    }

    #[test]
    fn test_get_outgoing_relationships() {
        let kg = create_test_graph();
        let rels = kg.get_outgoing_relationships("e1").unwrap();
        assert_eq!(rels.len(), 2);
    }

    #[test]
    fn test_find_path() {
        let kg = create_test_graph();
        let path = kg.find_path("e1", "e2").unwrap();
        assert!(path.is_some());
    }

    #[test]
    fn test_remove_entity() {
        let mut kg = create_test_graph();
        assert!(kg.remove_entity("e1").is_some());
        assert_eq!(kg.entity_count(), 2);
    }

    #[test]
    fn test_kg_save_and_load() {
        use std::fs;
        use tempfile::NamedTempFile;

        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap().to_string();

        // Create and save graph
        let kg = create_test_graph();
        kg.save_to_file(&path).unwrap();
        assert!(fs::metadata(&path).is_ok());

        // Load
        let loaded = KnowledgeGraph::load_from_file(&path).unwrap();
        assert_eq!(loaded.entity_count(), 3);
        assert_eq!(loaded.relationship_count(), 3);

        // Verify entities
        let entity = loaded.get_entity("e1").unwrap();
        assert_eq!(entity.name, "Alice");

        fs::remove_file(&path).ok();
    }

    #[test]
    fn test_kg_save_and_load_empty() {
        use std::fs;
        use tempfile::NamedTempFile;

        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap().to_string();

        let kg = KnowledgeGraph::new();
        kg.save_to_file(&path).unwrap();

        let loaded = KnowledgeGraph::load_from_file(&path).unwrap();
        assert_eq!(loaded.entity_count(), 0);
        assert_eq!(loaded.relationship_count(), 0);

        fs::remove_file(&path).ok();
    }

    #[test]
    fn test_kg_serialize_deserialize() {
        let kg = create_test_graph();

        // Serialize
        let data = kg.serialize().unwrap();
        assert!(!data.is_empty());

        // Deserialize
        let loaded = KnowledgeGraph::deserialize(&data).unwrap();
        assert_eq!(loaded.entity_count(), kg.entity_count());
        assert_eq!(loaded.relationship_count(), kg.relationship_count());
    }
}
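
One behavior worth calling out: `find_path` runs petgraph's `astar` with a constant edge cost of 1 and a zero heuristic, so it returns a minimum-hop path that follows edge direction. A short sketch (module paths assumed from the file layout):

    use typedialog_core::ai::kg::entities::{Entity, EntityType, Relationship, RelationType};
    use typedialog_core::ai::kg::graph::KnowledgeGraph;

    fn graph_demo() -> typedialog_core::error::Result<()> {
        let mut kg = KnowledgeGraph::new();
        kg.add_entity(Entity::new("e1".into(), "Alice".into(), EntityType::Person))?;
        kg.add_entity(Entity::new("e2".into(), "ACME".into(), EntityType::Organization))?;
        kg.add_relationship(Relationship::new("e1".into(), "e2".into(), RelationType::WorksFor))?;
        // The graph is directed: e1 -> e2 exists, the reverse does not.
        assert!(kg.find_path("e1", "e2")?.is_some());
        assert!(kg.find_path("e2", "e1")?.is_none());
        Ok(())
    }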
crates/typedialog-core/src/ai/kg/integration.rs (new file, 322 lines)
@@ -0,0 +1,322 @@
//! RAG + Knowledge Graph Integration
//!
//! Combines retrieval-augmented generation with knowledge graph
//! for contextual entity extraction and relationship building.

use crate::error::{ErrorWrapper, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[cfg(all(feature = "ai_backend", test))]
use super::entities::RelationType;
#[cfg(feature = "ai_backend")]
use super::entities::{Entity, EntityType, Relationship};
#[cfg(feature = "ai_backend")]
use super::graph::KnowledgeGraph;
#[cfg(feature = "ai_backend")]
use crate::ai::rag::RetrievalResult;

/// Extracted entities from a document
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentEntities {
    /// Document ID
    pub doc_id: String,
    /// Extracted entities
    pub entities: Vec<Entity>,
    /// Relationships between entities
    pub relationships: Vec<Relationship>,
}

impl DocumentEntities {
    /// Create new document entities
    pub fn new(doc_id: String) -> Self {
        DocumentEntities {
            doc_id,
            entities: Vec::new(),
            relationships: Vec::new(),
        }
    }

    /// Add extracted entity
    pub fn add_entity(&mut self, entity: Entity) {
        self.entities.push(entity);
    }

    /// Add relationship
    pub fn add_relationship(&mut self, rel: Relationship) {
        self.relationships.push(rel);
    }
}

/// KG-augmented retrieval result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KgAugmentedResult {
    /// Original retrieval result
    pub retrieval: RetrievalResult,
    /// Entities mentioned in the document
    pub entities: Vec<Entity>,
    /// Related entities from knowledge graph
    pub related_entities: Vec<Entity>,
    /// Relationships to explore
    pub relationships: Vec<Relationship>,
}

/// RAG + KG Integration Manager
#[cfg(feature = "ai_backend")]
pub struct RagKgIntegration {
    kg: KnowledgeGraph,
    doc_entities: HashMap<String, DocumentEntities>,
}

#[cfg(feature = "ai_backend")]
impl RagKgIntegration {
    /// Create new integration manager
    pub fn new() -> Self {
        RagKgIntegration {
            kg: KnowledgeGraph::new(),
            doc_entities: HashMap::new(),
        }
    }

    /// Add document with extracted entities and relationships
    pub fn add_document(&mut self, doc_entities: DocumentEntities) -> Result<()> {
        if doc_entities.doc_id.is_empty() {
            return Err(ErrorWrapper::validation_failed(
                "Document ID cannot be empty",
            ));
        }

        // Add entities to knowledge graph
        for entity in doc_entities.entities.clone() {
            let entity_with_source = entity.with_source(doc_entities.doc_id.clone());
            self.kg.add_entity(entity_with_source)?;
        }

        // Add relationships to knowledge graph
        for rel in doc_entities.relationships.clone() {
            self.kg.add_relationship(rel)?;
        }

        // Store document entities
        self.doc_entities
            .insert(doc_entities.doc_id.clone(), doc_entities);

        Ok(())
    }

    /// Augment retrieval results with KG context
    pub fn augment_result(&self, result: RetrievalResult) -> Result<KgAugmentedResult> {
        let doc_entities = self
            .doc_entities
            .get(&result.doc_id)
            .map(|de| de.entities.clone())
            .unwrap_or_default();

        // Find related entities in the graph
        let mut related_entities = Vec::new();
        let mut relationships = Vec::new();

        for entity in &doc_entities {
            // Get neighbors
            if let Ok(neighbors) = self.kg.get_neighbors(&entity.id) {
                related_entities.extend(neighbors.iter().map(|e| (*e).clone()));
            }

            // Get outgoing relationships
            if let Ok(rels) = self.kg.get_outgoing_relationships(&entity.id) {
                relationships.extend(rels.iter().map(|r| (*r).clone()));
            }
        }

        Ok(KgAugmentedResult {
            retrieval: result,
            entities: doc_entities,
            related_entities,
            relationships,
        })
    }

    /// Get entity context
    pub fn get_entity_context(&self, entity_id: &str) -> Result<EntityContext> {
        if !self.kg.contains_entity(entity_id) {
            return Err(ErrorWrapper::validation_failed(format!(
                "Entity '{}' not found",
                entity_id
            )));
        }

        let entity = self.kg.get_entity(entity_id).unwrap().clone();

        let incoming = self
            .kg
            .get_incoming_neighbors(entity_id)
            .unwrap_or_default()
            .iter()
            .map(|e| (*e).clone())
            .collect();

        let outgoing = self
            .kg
            .get_neighbors(entity_id)
            .unwrap_or_default()
            .iter()
            .map(|e| (*e).clone())
            .collect();

        Ok(EntityContext {
            entity,
            incoming_entities: incoming,
            outgoing_entities: outgoing,
        })
    }

    /// Find entities related to a query
    pub fn find_related_entities(&self, query: &str, limit: usize) -> Vec<Entity> {
        let query_lower = query.to_lowercase();
        self.kg
            .all_entities()
            .iter()
            .filter(|e| {
                e.name.to_lowercase().contains(&query_lower)
                    || e.description
                        .as_ref()
                        .map(|d| d.to_lowercase().contains(&query_lower))
                        .unwrap_or(false)
            })
            .take(limit)
            .map(|e| (*e).clone())
            .collect()
    }

    /// Get knowledge graph statistics
    pub fn get_stats(&self) -> KgStats {
        let entity_types = self
            .kg
            .all_entities()
            .iter()
            .fold(HashMap::new(), |mut acc, entity| {
                *acc.entry(entity.entity_type).or_insert(0) += 1;
                acc
            });

        let rel_types = self
            .kg
            .all_relationships()
            .iter()
            .fold(HashMap::new(), |mut acc, rel| {
                let type_str = rel.rel_type.as_str().to_string();
                *acc.entry(type_str).or_insert(0) += 1;
                acc
            });

        KgStats {
            total_entities: self.kg.entity_count(),
            total_relationships: self.kg.relationship_count(),
            entity_types,
            relationship_types: rel_types,
        }
    }

    /// Get knowledge graph reference
    pub fn graph(&self) -> &KnowledgeGraph {
        &self.kg
    }

    /// Get mutable knowledge graph reference
    pub fn graph_mut(&mut self) -> &mut KnowledgeGraph {
        &mut self.kg
    }
}

#[cfg(feature = "ai_backend")]
impl Default for RagKgIntegration {
    fn default() -> Self {
        Self::new()
    }
}

/// Entity context in the knowledge graph
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityContext {
    /// The entity itself
    pub entity: Entity,
    /// Entities pointing to this entity
    pub incoming_entities: Vec<Entity>,
    /// Entities this entity points to
    pub outgoing_entities: Vec<Entity>,
}

/// Knowledge graph statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KgStats {
    /// Total number of entities
    pub total_entities: usize,
    /// Total number of relationships
    pub total_relationships: usize,
    /// Count by entity type
    pub entity_types: HashMap<EntityType, usize>,
    /// Count by relationship type
    pub relationship_types: HashMap<String, usize>,
}

#[cfg(all(test, feature = "ai_backend"))]
mod tests {
    use super::*;

    fn create_doc_entities() -> DocumentEntities {
        let mut doc = DocumentEntities::new("doc1".to_string());
        let alice = Entity::new("e1".into(), "Alice".into(), EntityType::Person)
            .with_description("Software engineer".into());
        let acme = Entity::new("e2".into(), "ACME".into(), EntityType::Organization);

        doc.add_entity(alice);
        doc.add_entity(acme);
        doc.add_relationship(Relationship::new(
            "e1".into(),
            "e2".into(),
            RelationType::WorksFor,
        ));

        doc
    }

    #[test]
    fn test_add_document() {
        let mut integration = RagKgIntegration::new();
        let doc = create_doc_entities();
        assert!(integration.add_document(doc).is_ok());
        assert_eq!(integration.kg.entity_count(), 2);
    }

    #[test]
    fn test_find_related_entities() {
        let mut integration = RagKgIntegration::new();
        let doc = create_doc_entities();
        integration.add_document(doc).unwrap();

        let results = integration.find_related_entities("Alice", 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "e1");
    }

    #[test]
    fn test_get_stats() {
        let mut integration = RagKgIntegration::new();
        let doc = create_doc_entities();
        integration.add_document(doc).unwrap();

        let stats = integration.get_stats();
        assert_eq!(stats.total_entities, 2);
        assert_eq!(stats.total_relationships, 1);
    }

    #[test]
    fn test_get_entity_context() {
        let mut integration = RagKgIntegration::new();
        let doc = create_doc_entities();
        integration.add_document(doc).unwrap();

        let context = integration.get_entity_context("e1").unwrap();
        assert_eq!(context.entity.id, "e1");
    }
}
|
||||
50
crates/typedialog-core/src/ai/kg/mod.rs
Normal file
@ -0,0 +1,50 @@
//! Knowledge Graph Module
//!
//! Implements a knowledge graph using petgraph for entity and relationship management,
//! with integration into the RAG system for contextual retrieval.
//!
//! # Components
//!
//! - **entities**: Entity and relationship types
//! - **graph**: Knowledge graph using petgraph DiGraph
//! - **traversal**: Graph traversal and relationship inference algorithms
//! - **integration**: RAG + KG integration for augmented retrieval
//!
//! # Example
//!
//! ```ignore
//! use typedialog_core::ai::kg::{KnowledgeGraph, Entity, EntityType, Relationship, RelationType};
//!
//! let mut kg = KnowledgeGraph::new();
//!
//! let alice = Entity::new("e1".into(), "Alice".into(), EntityType::Person);
//! let acme = Entity::new("e2".into(), "ACME Corp".into(), EntityType::Organization);
//!
//! kg.add_entity(alice)?;
//! kg.add_entity(acme)?;
//!
//! kg.add_relationship(Relationship::new(
//!     "e1".into(),
//!     "e2".into(),
//!     RelationType::WorksFor
//! ))?;
//! ```

#[cfg(feature = "ai_backend")]
pub mod entities;
#[cfg(feature = "ai_backend")]
pub mod graph;
#[cfg(feature = "ai_backend")]
pub mod integration;
#[cfg(feature = "ai_backend")]
pub mod traversal;

// Public exports when ai_backend feature is enabled
#[cfg(feature = "ai_backend")]
pub use entities::{Entity, EntityType, RelationType, Relationship};
#[cfg(feature = "ai_backend")]
pub use graph::KnowledgeGraph;
#[cfg(feature = "ai_backend")]
pub use integration::{EntityContext, KgAugmentedResult, KgStats, RagKgIntegration};
#[cfg(feature = "ai_backend")]
pub use traversal::{GraphTraversal, InferredRelationship, Path};
358
crates/typedialog-core/src/ai/kg/traversal.rs
Normal file
@ -0,0 +1,358 @@
//! Graph Traversal and Relationship Inference
//!
//! Provides algorithms for traversing the knowledge graph and inferring
//! new relationships based on existing connections.
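//!
//! # Example
//!
//! A minimal sketch of the traversal API defined below (entity ID "e1" is hypothetical):
//!
//! ```ignore
//! use typedialog_core::ai::kg::{GraphTraversal, KnowledgeGraph};
//!
//! let kg = KnowledgeGraph::new();
//! // Collect every entity reachable within two hops, breadth-first.
//! let reachable = GraphTraversal::bfs(&kg, "e1", 2)?;
//! ```
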
use crate::error::{ErrorWrapper, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};

#[cfg(all(test, feature = "ai_backend"))]
use super::entities::Relationship;
#[cfg(feature = "ai_backend")]
use super::entities::{Entity, RelationType};
#[cfg(feature = "ai_backend")]
use super::graph::KnowledgeGraph;

/// Path in the knowledge graph
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Path {
    /// Sequence of entity IDs
    pub entities: Vec<String>,
    /// Relationships along the path
    pub relationships: Vec<RelationType>,
    /// Total path weight (sum of relationship weights)
    pub total_weight: f32,
}

impl Path {
    /// Get path length
    pub fn length(&self) -> usize {
        self.entities.len()
    }

    /// Get hop count (number of relationships)
    pub fn hops(&self) -> usize {
        self.relationships.len()
    }
}

/// Result of relationship inference
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferredRelationship {
    /// Source entity ID
    pub source: String,
    /// Target entity ID
    pub target: String,
    /// Inferred relationship type
    pub rel_type: RelationType,
    /// Confidence score (0.0 - 1.0)
    pub confidence: f32,
    /// The path that led to this inference
    pub path: Path,
}

/// Graph traversal methods
#[cfg(feature = "ai_backend")]
pub struct GraphTraversal;

#[cfg(feature = "ai_backend")]
impl GraphTraversal {
    /// Breadth-first search for entities
    pub fn bfs<'a>(
        kg: &'a KnowledgeGraph,
        start_id: &str,
        max_depth: usize,
    ) -> Result<Vec<&'a Entity>> {
        if !kg.contains_entity(start_id) {
            return Err(ErrorWrapper::validation_failed(format!(
                "Entity '{}' not found",
                start_id
            )));
        }

        let mut visited = std::collections::HashSet::new();
        let mut queue = VecDeque::new();
        let mut results = Vec::new();

        queue.push_back((start_id.to_string(), 0));
        visited.insert(start_id.to_string());

        while let Some((id, depth)) = queue.pop_front() {
            if let Some(entity) = kg.get_entity(&id) {
                results.push(entity);

                if depth < max_depth {
                    if let Ok(neighbors) = kg.get_neighbors(&id) {
                        for neighbor in neighbors {
                            if !visited.contains(&neighbor.id) {
                                visited.insert(neighbor.id.clone());
                                queue.push_back((neighbor.id.clone(), depth + 1));
                            }
                        }
                    }
                }
            }
        }

        Ok(results)
    }

    /// Depth-first search for entities
    pub fn dfs<'a>(
        kg: &'a KnowledgeGraph,
        start_id: &str,
        max_depth: usize,
    ) -> Result<Vec<&'a Entity>> {
        if !kg.contains_entity(start_id) {
            return Err(ErrorWrapper::validation_failed(format!(
                "Entity '{}' not found",
                start_id
            )));
        }

        let mut visited = std::collections::HashSet::new();
        let mut results = Vec::new();

        fn dfs_recursive<'a>(
            kg: &'a KnowledgeGraph,
            id: &str,
            depth: usize,
            max_depth: usize,
            visited: &mut std::collections::HashSet<String>,
            results: &mut Vec<&'a Entity>,
        ) {
            if let Some(entity) = kg.get_entity(id) {
                results.push(entity);

                if depth < max_depth {
                    if let Ok(neighbors) = kg.get_neighbors(id) {
                        for neighbor in neighbors {
                            if !visited.contains(&neighbor.id) {
                                visited.insert(neighbor.id.clone());
                                dfs_recursive(
                                    kg,
                                    &neighbor.id,
                                    depth + 1,
                                    max_depth,
                                    visited,
                                    results,
                                );
                            }
                        }
                    }
                }
            }
        }

        visited.insert(start_id.to_string());
        dfs_recursive(kg, start_id, 0, max_depth, &mut visited, &mut results);

        Ok(results)
    }

    /// Find all entities within n hops
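    ///
    /// Returns a map from entity ID to its hop distance from `start_id`;
    /// the start entity itself is reported at distance 0. A minimal sketch:
    ///
    /// ```ignore
    /// let distances = GraphTraversal::find_within_distance(&kg, "e1", 2)?;
    /// assert_eq!(distances.get("e1"), Some(&0));
    /// ```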
    pub fn find_within_distance(
        kg: &KnowledgeGraph,
        start_id: &str,
        max_hops: usize,
    ) -> Result<HashMap<String, usize>> {
        if !kg.contains_entity(start_id) {
            return Err(ErrorWrapper::validation_failed(format!(
                "Entity '{}' not found",
                start_id
            )));
        }

        let mut distances = HashMap::new();
        let mut queue = VecDeque::new();

        queue.push_back((start_id.to_string(), 0));
        distances.insert(start_id.to_string(), 0);

        while let Some((id, dist)) = queue.pop_front() {
            if dist < max_hops {
                if let Ok(neighbors) = kg.get_neighbors(&id) {
                    for neighbor in neighbors {
                        if !distances.contains_key(&neighbor.id) {
                            distances.insert(neighbor.id.clone(), dist + 1);
                            queue.push_back((neighbor.id.clone(), dist + 1));
                        }
                    }
                }
            }
        }

        Ok(distances)
    }

    /// Infer transitive relationships
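    ///
    /// Follows chains of the given relationship type outward from `start_id`,
    /// multiplying edge weights into a confidence score for each inferred link.
    /// A minimal sketch:
    ///
    /// ```ignore
    /// // Entities "e1" transitively relates to via WorksFor, up to 3 hops.
    /// let inferred =
    ///     GraphTraversal::infer_transitive(&kg, "e1", &RelationType::WorksFor, 3)?;
    /// ```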
    pub fn infer_transitive(
        kg: &KnowledgeGraph,
        start_id: &str,
        rel_type: &RelationType,
        max_hops: usize,
    ) -> Result<Vec<InferredRelationship>> {
        if !kg.contains_entity(start_id) {
            return Err(ErrorWrapper::validation_failed(format!(
                "Entity '{}' not found",
                start_id
            )));
        }

        let mut inferred = Vec::new();
        let mut visited = std::collections::HashSet::new();
        let mut queue = VecDeque::new();

        queue.push_back((start_id.to_string(), Vec::new(), 1.0, Vec::new()));
        visited.insert(start_id.to_string());

        while let Some((current_id, mut entity_path, confidence, rel_path)) = queue.pop_front() {
            entity_path.push(current_id.clone());

            if let Ok(relationships) = kg.get_outgoing_relationships(&current_id) {
                for rel in relationships {
                    if rel.rel_type == *rel_type && rel_path.len() < max_hops {
                        // Extend a per-edge copy of the relationship path so that
                        // sibling edges do not accumulate each other's hops
                        let mut new_rel_path = rel_path.clone();
                        new_rel_path.push(rel.rel_type.clone());
                        let new_confidence = confidence * rel.weight;

                        if !visited.contains(&rel.target) {
                            visited.insert(rel.target.clone());

                            if entity_path.len() > 1 {
                                inferred.push(InferredRelationship {
                                    source: start_id.to_string(),
                                    target: rel.target.clone(),
                                    rel_type: rel_type.clone(),
                                    confidence: new_confidence,
                                    path: Path {
                                        entities: entity_path.clone(),
                                        relationships: new_rel_path.clone(),
                                        total_weight: new_confidence,
                                    },
                                });
                            }

                            queue.push_back((
                                rel.target.clone(),
                                entity_path.clone(),
                                new_confidence,
                                new_rel_path,
                            ));
                        }
                    }
                }
            }
        }

        Ok(inferred)
    }

    /// Find common neighbors
    pub fn find_common_neighbors<'a>(
        kg: &'a KnowledgeGraph,
        id1: &str,
        id2: &str,
    ) -> Result<Vec<&'a Entity>> {
        let neighbors1 = kg.get_neighbors(id1)?;
        let neighbors2 = kg.get_neighbors(id2)?;

        let ids1: std::collections::HashSet<_> =
            neighbors1.iter().map(|e| e.id.as_str()).collect();

        Ok(neighbors2
            .into_iter()
            .filter(|e| ids1.contains(e.id.as_str()))
            .collect())
    }
}

#[cfg(all(test, feature = "ai_backend"))]
mod tests {
    use super::*;

    fn create_test_kg() -> KnowledgeGraph {
        let mut kg = KnowledgeGraph::new();
        let alice = Entity::new(
            "e1".into(),
            "Alice".into(),
            crate::ai::kg::entities::EntityType::Person,
        );
        let bob = Entity::new(
            "e2".into(),
            "Bob".into(),
            crate::ai::kg::entities::EntityType::Person,
        );
        let charlie = Entity::new(
            "e3".into(),
            "Charlie".into(),
            crate::ai::kg::entities::EntityType::Person,
        );
        let acme = Entity::new(
            "e4".into(),
            "ACME".into(),
            crate::ai::kg::entities::EntityType::Organization,
        );

        kg.add_entity(alice).unwrap();
        kg.add_entity(bob).unwrap();
        kg.add_entity(charlie).unwrap();
        kg.add_entity(acme).unwrap();

        kg.add_relationship(Relationship::new(
            "e1".into(),
            "e4".into(),
            RelationType::WorksFor,
        ))
        .unwrap();
        kg.add_relationship(Relationship::new(
            "e2".into(),
            "e4".into(),
            RelationType::WorksFor,
        ))
        .unwrap();
        kg.add_relationship(Relationship::new(
            "e3".into(),
            "e4".into(),
            RelationType::WorksFor,
        ))
        .unwrap();
        kg.add_relationship(Relationship::new(
            "e1".into(),
            "e2".into(),
            RelationType::RelatedTo,
        ))
        .unwrap();

        kg
    }

    #[test]
    fn test_bfs() {
        let kg = create_test_kg();
        let results = GraphTraversal::bfs(&kg, "e1", 2).unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn test_dfs() {
        let kg = create_test_kg();
        let results = GraphTraversal::dfs(&kg, "e1", 2).unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn test_find_within_distance() {
        let kg = create_test_kg();
        let distances = GraphTraversal::find_within_distance(&kg, "e1", 2).unwrap();
        assert!(distances.contains_key("e1"));
        assert_eq!(distances.get("e1"), Some(&0));
    }

    #[test]
    fn test_common_neighbors() {
        let kg = create_test_kg();
        let common = GraphTraversal::find_common_neighbors(&kg, "e1", "e2").unwrap();
        assert!(!common.is_empty());
    }
}
57
crates/typedialog-core/src/ai/mod.rs
Normal file
@ -0,0 +1,57 @@
//! AI Backend Module
//!
//! Provides embeddings, vector search, full-text indexing, and RAG capabilities
//! for semantic and keyword-based document retrieval.
//!
//! # Features
//!
//! - **Embeddings**: Convert text to dense vectors using fastembed
//! - **Vector Store**: HNSW-based approximate nearest neighbor search with instant-distance
//! - **Full-Text Indexing**: BM25-like scoring with tantivy integration
//! - **RAG System**: Hybrid retrieval combining semantic and keyword search
//!
//! # Example
//!
//! ```ignore
//! use typedialog_core::ai::{RagSystem, RagConfig};
//!
//! fn main() -> Result<()> {
//!     let config = RagConfig::default();
//!     let mut rag = RagSystem::new(config)?;
//!
//!     rag.add_document("doc1".to_string(), "Hello world".to_string())?;
//!     let results = rag.retrieve("hello")?;
//!     Ok(())
//! }
//! ```

#[cfg(feature = "ai_backend")]
pub mod embeddings;
#[cfg(feature = "ai_backend")]
pub mod indexer;
#[cfg(feature = "ai_backend")]
pub mod kg;
#[cfg(feature = "ai_backend")]
pub mod persistence;
#[cfg(feature = "ai_backend")]
pub mod rag;
#[cfg(feature = "ai_backend")]
pub mod vector_store;

// Public exports when ai_backend feature is enabled
#[cfg(feature = "ai_backend")]
pub use embeddings::{EmbeddingModel, EmbeddingsService};
#[cfg(feature = "ai_backend")]
pub use indexer::{Document, FullTextIndexer};
#[cfg(feature = "ai_backend")]
pub use kg::{
    Entity, EntityType, GraphTraversal, KnowledgeGraph, RagKgIntegration, RelationType,
    Relationship,
};
#[cfg(feature = "ai_backend")]
pub use persistence::{KnowledgeGraphSnapshot, Persistence, RagSystemSnapshot};
#[cfg(feature = "ai_backend")]
pub use rag::{RagConfig, RagSystem, RetrievalResult, RetrievalSource};
#[cfg(feature = "ai_backend")]
pub use vector_store::{VectorData, VectorStore};
256
crates/typedialog-core/src/ai/persistence.rs
Normal file
@ -0,0 +1,256 @@
//! Persistence and serialization for AI backend
//!
//! Provides save/load functionality for RAG systems and knowledge graphs
//! using binary serialization (bincode) for efficiency and optional JSON for debugging.
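//!
//! # Example
//!
//! A minimal sketch of the round-trip, assuming a serializable `snapshot` value
//! and a writable `state/` directory:
//!
//! ```ignore
//! Persistence::save_binary(&snapshot, "state/rag.bin")?;
//! let restored: RagSystemSnapshot = Persistence::load_binary("state/rag.bin")?;
//! ```
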
use crate::error::{ErrorWrapper, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;

/// Schema version for forward/backward compatibility
const PERSISTENCE_VERSION: u32 = 1;

/// RAG system snapshot for persistence
///
/// Captures the complete state of a RAG system including vectors,
/// full-text index, and embeddings configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagSystemSnapshot {
    /// Schema version for compatibility
    pub version: u32,
    /// Serialized vector store data
    pub vector_store_data: Vec<u8>,
    /// Serialized full-text indexer data
    pub indexer_data: Vec<u8>,
    /// Embeddings model configuration
    pub embeddings_model: String,
    /// RAG configuration (weights, max_results, etc.)
    pub config_data: Vec<u8>,
}

/// Knowledge graph snapshot for persistence
///
/// Captures entities, relationships, and graph structure.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeGraphSnapshot {
    /// Schema version for compatibility
    pub version: u32,
    /// Serialized graph data
    pub graph_data: Vec<u8>,
}

/// Persistence utilities
pub struct Persistence;

impl Persistence {
    /// Save data to a binary file using bincode
    ///
    /// Uses atomic writes to prevent corruption on failure:
    /// creates a temporary file, writes to it, then renames it over the target.
    pub fn save_binary<T: Serialize>(data: &T, path: &str) -> Result<()> {
        let serialized = bincode::serialize(data).map_err(|e| {
            ErrorWrapper::validation_failed(format!("Failed to serialize data: {}", e))
        })?;

        // Write to temp file first (atomic write)
        let path_obj = Path::new(path);
        let parent = path_obj
            .parent()
            .ok_or_else(|| ErrorWrapper::validation_failed("Invalid path".to_string()))?;

        // Create parent directories if needed
        if !parent.as_os_str().is_empty() && !parent.exists() {
            fs::create_dir_all(parent).map_err(|e| {
                ErrorWrapper::validation_failed(format!("Failed to create directories: {}", e))
            })?;
        }

        // Write with atomic rename
        let temp_path = format!("{}.tmp", path);
        fs::write(&temp_path, serialized)
            .map_err(|e| ErrorWrapper::validation_failed(format!("Failed to write file: {}", e)))?;

        fs::rename(&temp_path, path).map_err(|e| {
            // Clean up temp file on rename failure
            drop(fs::remove_file(&temp_path));
            ErrorWrapper::validation_failed(format!("Failed to finalize save: {}", e))
        })?;

        Ok(())
    }

    /// Load data from a binary file using bincode
    pub fn load_binary<T: for<'de> Deserialize<'de>>(path: &str) -> Result<T> {
        let data = fs::read(path)
            .map_err(|e| ErrorWrapper::validation_failed(format!("Failed to read file: {}", e)))?;

        bincode::deserialize(&data).map_err(|e| {
            ErrorWrapper::validation_failed(format!("Failed to deserialize data: {}", e))
        })
    }

    /// Save data to a JSON file for human readability
    ///
    /// Useful for debugging and inspection, but larger than the binary format.
    pub fn save_json<T: Serialize>(data: &T, path: &str) -> Result<()> {
        let json = serde_json::to_string_pretty(data).map_err(|e| {
            ErrorWrapper::validation_failed(format!("Failed to serialize to JSON: {}", e))
        })?;

        let path_obj = Path::new(path);
        let parent = path_obj
            .parent()
            .ok_or_else(|| ErrorWrapper::validation_failed("Invalid path".to_string()))?;

        if !parent.as_os_str().is_empty() && !parent.exists() {
            fs::create_dir_all(parent).map_err(|e| {
                ErrorWrapper::validation_failed(format!("Failed to create directories: {}", e))
            })?;
        }

        let temp_path = format!("{}.tmp", path);
        fs::write(&temp_path, json)
            .map_err(|e| ErrorWrapper::validation_failed(format!("Failed to write file: {}", e)))?;

        fs::rename(&temp_path, path).map_err(|e| {
            drop(fs::remove_file(&temp_path));
            ErrorWrapper::validation_failed(format!("Failed to finalize save: {}", e))
        })?;

        Ok(())
    }

    /// Load data from a JSON file
    pub fn load_json<T: for<'de> Deserialize<'de>>(path: &str) -> Result<T> {
        let json = fs::read_to_string(path)
            .map_err(|e| ErrorWrapper::validation_failed(format!("Failed to read file: {}", e)))?;

        serde_json::from_str(&json).map_err(|e| {
            ErrorWrapper::validation_failed(format!("Failed to deserialize from JSON: {}", e))
        })
    }

    /// Get current persistence version
    pub fn current_version() -> u32 {
        PERSISTENCE_VERSION
    }

    /// Check if snapshot version is compatible
    pub fn is_version_compatible(version: u32) -> bool {
        // For now, exact version match required
        // Future: implement migration logic for version > 1
        version == PERSISTENCE_VERSION
    }
}

#[cfg(all(test, feature = "ai_backend"))]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::NamedTempFile;

    #[test]
    fn test_save_and_load_json() {
        let test_data = RagSystemSnapshot {
            version: PERSISTENCE_VERSION,
            vector_store_data: vec![1, 2, 3],
            indexer_data: vec![4, 5, 6],
            embeddings_model: "test_model".to_string(),
            config_data: vec![7, 8, 9],
        };

        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap().to_string();

        // Save
        Persistence::save_json(&test_data, &path).unwrap();
        assert!(fs::metadata(&path).is_ok());

        // Load
        let loaded: RagSystemSnapshot = Persistence::load_json(&path).unwrap();
        assert_eq!(loaded.version, test_data.version);
        assert_eq!(loaded.vector_store_data, test_data.vector_store_data);

        fs::remove_file(&path).ok();
    }

    #[test]
    fn test_save_and_load_binary() {
        let test_data = RagSystemSnapshot {
            version: PERSISTENCE_VERSION,
            vector_store_data: vec![10, 20, 30],
            indexer_data: vec![40, 50, 60],
            embeddings_model: "another_model".to_string(),
            config_data: vec![70, 80, 90],
        };

        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap().to_string();

        // Save
        Persistence::save_binary(&test_data, &path).unwrap();
        assert!(fs::metadata(&path).is_ok());

        // Load
        let loaded: RagSystemSnapshot = Persistence::load_binary(&path).unwrap();
        assert_eq!(loaded.version, test_data.version);
        assert_eq!(loaded.indexer_data, test_data.indexer_data);

        fs::remove_file(&path).ok();
    }

    #[test]
    fn test_atomic_write_cleanup() {
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap().to_string();

        let test_data = RagSystemSnapshot {
            version: PERSISTENCE_VERSION,
            vector_store_data: vec![1, 2, 3],
            indexer_data: vec![4, 5, 6],
            embeddings_model: "model".to_string(),
            config_data: vec![7, 8, 9],
        };

        Persistence::save_binary(&test_data, &path).unwrap();

        // Check temp file doesn't exist
        let temp_path = format!("{}.tmp", path);
        assert!(fs::metadata(&temp_path).is_err());

        fs::remove_file(&path).ok();
    }

    #[test]
    fn test_version_compatibility() {
        assert!(Persistence::is_version_compatible(PERSISTENCE_VERSION));
        assert!(!Persistence::is_version_compatible(PERSISTENCE_VERSION + 1));
    }

    #[test]
    fn test_load_nonexistent_file() {
        let result: Result<RagSystemSnapshot> =
            Persistence::load_binary("/nonexistent/path/file.bin");
        assert!(result.is_err());
    }

    #[test]
    fn test_kg_snapshot() {
        let kg_snap = KnowledgeGraphSnapshot {
            version: PERSISTENCE_VERSION,
            graph_data: vec![1, 2, 3, 4, 5],
        };

        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap().to_string();

        Persistence::save_json(&kg_snap, &path).unwrap();
        let loaded: KnowledgeGraphSnapshot = Persistence::load_json(&path).unwrap();

        assert_eq!(loaded.version, kg_snap.version);
        assert_eq!(loaded.graph_data, kg_snap.graph_data);

        fs::remove_file(&path).ok();
    }
}
550
crates/typedialog-core/src/ai/rag.rs
Normal file
@ -0,0 +1,550 @@
//! RAG (Retrieval-Augmented Generation) orchestration
//!
//! Combines semantic search (embeddings + vector store) with
//! full-text search to retrieve relevant documents for LLM queries.
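//!
//! # Example
//!
//! A minimal sketch of the synchronous API defined below:
//!
//! ```ignore
//! let mut rag = RagSystem::new(RagConfig::default())?;
//! rag.add_document("doc1".to_string(), "Hello world".to_string())?;
//! let hits = rag.retrieve("hello")?;
//! ```
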
use crate::error::{ErrorWrapper, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[cfg(feature = "ai_backend")]
use super::embeddings::{EmbeddingModel, EmbeddingsService};
#[cfg(feature = "ai_backend")]
use super::indexer::{Document, FullTextIndexer};
#[cfg(feature = "ai_backend")]
use super::vector_store::{VectorData, VectorStore};

/// RAG retrieval result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResult {
    /// Document ID
    pub doc_id: String,
    /// Document content
    pub content: String,
    /// Semantic relevance score (0-1)
    pub semantic_score: f32,
    /// Keyword relevance score (0-1)
    pub keyword_score: f32,
    /// Combined relevance score
    pub combined_score: f32,
    /// Source type (semantic or keyword)
    pub source: RetrievalSource,
}

/// Source of retrieval
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RetrievalSource {
    /// Retrieved via semantic search
    Semantic,
    /// Retrieved via keyword search
    Keyword,
    /// Retrieved via both methods
    Hybrid,
}

/// RAG configuration
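///
/// A minimal sketch of a keyword-leaning configuration (values are illustrative):
///
/// ```ignore
/// let config = RagConfig {
///     semantic_weight: 0.3,
///     keyword_weight: 0.7,
///     max_results: 10,
///     min_score: 0.1,
/// };
/// ```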
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// Weight for semantic search results (0.0-1.0)
    pub semantic_weight: f32,
    /// Weight for keyword search results (0.0-1.0)
    pub keyword_weight: f32,
    /// Maximum number of results to return
    pub max_results: usize,
    /// Minimum combined score threshold
    pub min_score: f32,
}

impl Default for RagConfig {
    fn default() -> Self {
        RagConfig {
            semantic_weight: 0.6,
            keyword_weight: 0.4,
            max_results: 5,
            min_score: 0.0,
        }
    }
}

/// RAG system combining semantic and keyword search
#[cfg(feature = "ai_backend")]
pub struct RagSystem {
    embeddings: EmbeddingsService,
    vector_store: VectorStore,
    indexer: FullTextIndexer,
    config: RagConfig,
}

#[cfg(feature = "ai_backend")]
impl RagSystem {
    /// Create a new RAG system
    pub fn new(config: RagConfig) -> Result<Self> {
        let embeddings = EmbeddingsService::new(EmbeddingModel::BgeSmallEn)?;
        let vector_store = VectorStore::new(config.max_results);
        let indexer = FullTextIndexer::new();

        Ok(RagSystem {
            embeddings,
            vector_store,
            indexer,
            config,
        })
    }

    /// Create with custom embedding model
    pub fn with_model(model: EmbeddingModel, config: RagConfig) -> Result<Self> {
        let embeddings = EmbeddingsService::new(model)?;
        let vector_store = VectorStore::new(config.max_results);
        let indexer = FullTextIndexer::new();

        Ok(RagSystem {
            embeddings,
            vector_store,
            indexer,
            config,
        })
    }

    /// Add a document to the RAG system
    pub fn add_document(&mut self, id: String, content: String) -> Result<()> {
        // Index for full-text search
        let doc = Document::new(id.clone(), content.clone());
        self.indexer.index(doc)?;

        // Embed and add to vector store
        let embedding = self.embeddings.embed(&content)?;
        let vector_data = VectorData::new(id, embedding.vector);
        self.vector_store.insert(vector_data)?;

        Ok(())
    }

    /// Add multiple documents
    pub fn add_documents(&mut self, docs: Vec<(String, String)>) -> Result<()> {
        for (id, content) in docs {
            self.add_document(id, content)?;
        }
        Ok(())
    }

    /// Add multiple documents in a single batch operation
    ///
    /// More efficient than add_documents() as it:
    /// - Embeds all documents in a single batch call
    /// - Indexes all documents in the full-text indexer
    /// - Invalidates vector store cache only once
    ///
    /// This is significantly faster for large document sets (100+).
    ///
    /// # Arguments
    ///
    /// * `docs` - Vec of (id, content) tuples to add
    ///
    /// # Returns
    ///
    /// Returns error if any document fails validation or embedding
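    ///
    /// # Example
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// rag.add_documents_batch(vec![
    ///     ("a".to_string(), "first document".to_string()),
    ///     ("b".to_string(), "second document".to_string()),
    /// ])?;
    /// ```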
    pub fn add_documents_batch(&mut self, docs: Vec<(String, String)>) -> Result<()> {
        if docs.is_empty() {
            return Ok(());
        }

        // Index all documents in the full-text indexer
        for (id, content) in &docs {
            let doc = Document::new(id.clone(), content.clone());
            self.indexer.index(doc)?;
        }

        // Embed all documents in a single batch call
        let contents: Vec<&str> = docs.iter().map(|(_, content)| content.as_str()).collect();
        let embeddings = self.embeddings.embed_batch(&contents)?;

        // Prepare vector data for batch insertion
        let vector_data: Vec<VectorData> = docs
            .iter()
            .zip(embeddings.iter())
            .map(|((id, _), embedding)| VectorData::new(id.clone(), embedding.vector.clone()))
            .collect();

        // Insert all vectors in a single batch operation
        self.vector_store.insert_batch(vector_data)?;

        Ok(())
    }

    /// Retrieve relevant documents for a query
    pub fn retrieve(&mut self, query: &str) -> Result<Vec<RetrievalResult>> {
        if query.is_empty() {
            return Err(ErrorWrapper::validation_failed("Query cannot be empty"));
        }

        // Semantic search
        let embedding = self.embeddings.embed(query)?;
        let semantic_results = self
            .vector_store
            .search(&embedding.vector, self.config.max_results)?;

        // Keyword search
        let keyword_results = self.indexer.search(query, self.config.max_results)?;

        // Combine results
        let mut combined: HashMap<String, RetrievalResult> = HashMap::new();

        // Add semantic results
        for result in semantic_results.iter() {
            let normalized_score = 1.0 / (1.0 + result.distance); // Convert distance to similarity
            let combined_score = normalized_score * self.config.semantic_weight;

            combined.insert(
                result.id.clone(),
                RetrievalResult {
                    doc_id: result.id.clone(),
                    content: String::new(), // Filled from the indexer on a keyword hit
                    semantic_score: normalized_score,
                    keyword_score: 0.0,
                    combined_score,
                    source: RetrievalSource::Semantic,
                },
            );
        }

        // Add keyword results
        for result in keyword_results.iter() {
            let keyword_score = result.score.min(1.0); // Normalize to 0-1
            let keyword_weighted = keyword_score * self.config.keyword_weight;

            combined
                .entry(result.doc_id.clone())
                .and_modify(|r| {
                    r.keyword_score = keyword_score;
                    r.combined_score = (r.semantic_score * self.config.semantic_weight
                        + keyword_score * self.config.keyword_weight)
                        / (self.config.semantic_weight + self.config.keyword_weight);
                    r.source = RetrievalSource::Hybrid;
                    r.content = result.content.clone();
                })
                .or_insert_with(|| RetrievalResult {
                    doc_id: result.doc_id.clone(),
                    content: result.content.clone(),
                    semantic_score: 0.0,
                    keyword_score,
                    combined_score: keyword_weighted,
                    source: RetrievalSource::Keyword,
                });
        }

        // Filter and sort results
        let mut results: Vec<_> = combined
            .into_values()
            .filter(|r| r.combined_score >= self.config.min_score)
            .collect();

        results.sort_by(|a, b| {
            b.combined_score
                .partial_cmp(&a.combined_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Limit results
        results.truncate(self.config.max_results);

        Ok(results)
    }

    /// Remove a document from the RAG system
    pub fn remove_document(&mut self, id: &str) -> bool {
        let indexer_removed = self.indexer.remove(id).is_some();
        let store_removed = self.vector_store.remove(id).is_some();
        indexer_removed || store_removed
    }

    /// Remove multiple documents in a single batch operation
    ///
    /// More efficient than calling remove_document() multiple times as it:
    /// - Removes all documents from the full-text indexer
    /// - Removes all vectors from the vector store in a single batch
    /// - Invalidates vector store cache only once
    ///
    /// # Arguments
    ///
    /// * `ids` - Slice of document IDs to remove
    ///
    /// # Returns
    ///
    /// Returns the number of documents actually removed
    pub fn remove_documents_batch(&mut self, ids: &[&str]) -> usize {
        let mut count = 0;

        // Remove from full-text indexer
        for id in ids {
            if self.indexer.remove(id).is_some() {
                count += 1;
            }
        }

        // Remove from vector store in batch (single cache invalidation)
        let _store_count = self.vector_store.remove_batch(ids);

        count
    }

    /// Get document count
    pub fn doc_count(&self) -> usize {
        self.indexer.doc_count()
    }

    /// Get RAG configuration
    pub fn config(&self) -> &RagConfig {
        &self.config
    }

    /// Update RAG configuration
    pub fn set_config(&mut self, config: RagConfig) {
        self.vector_store.set_max_results(config.max_results);
        self.config = config;
    }

    /// Clear all documents
    pub fn clear(&mut self) {
        self.indexer.clear();
        self.vector_store.clear();
    }

    /// Save RAG system to a file using binary format
    ///
    /// Captures the complete state including vectors, full-text index, and configuration.
    pub fn save_to_file(&self, path: &str) -> Result<()> {
        let vector_store_data = self.vector_store.serialize()?;
        let indexer_data = self.indexer.serialize()?;
        let config_data = bincode::serialize(&self.config).map_err(|e| {
            ErrorWrapper::validation_failed(format!("Failed to serialize config: {}", e))
        })?;

        let snapshot = super::persistence::RagSystemSnapshot {
            version: super::persistence::Persistence::current_version(),
            vector_store_data,
            indexer_data,
            embeddings_model: self.embeddings.model_type().as_str().to_string(),
            config_data,
        };

        super::persistence::Persistence::save_binary(&snapshot, path)
    }

    /// Load RAG system from a file
    ///
    /// Restores vectors, full-text index, and configuration from saved state.
    pub fn load_from_file(path: &str) -> Result<Self> {
        let snapshot: super::persistence::RagSystemSnapshot =
            super::persistence::Persistence::load_binary(path)?;

        // Check version compatibility
        if !super::persistence::Persistence::is_version_compatible(snapshot.version) {
            return Err(ErrorWrapper::validation_failed(format!(
                "Incompatible snapshot version: {}",
                snapshot.version
            )));
        }

        // Deserialize components
        let vector_store =
            super::vector_store::VectorStore::deserialize(&snapshot.vector_store_data)?;
        let indexer = super::indexer::FullTextIndexer::deserialize(&snapshot.indexer_data)?;
        let config: RagConfig = bincode::deserialize(&snapshot.config_data).map_err(|e| {
            ErrorWrapper::validation_failed(format!("Failed to deserialize config: {}", e))
        })?;

        // Recreate embeddings service
        let model = EmbeddingModel::parse(&snapshot.embeddings_model);
        let embeddings = EmbeddingsService::new(model)?;

        Ok(RagSystem {
            embeddings,
            vector_store,
            indexer,
            config,
        })
    }
}

#[cfg(all(test, feature = "ai_backend"))]
mod tests {
    use super::*;

    #[test]
    fn test_rag_config_default() {
        let config = RagConfig::default();
        assert_eq!(config.semantic_weight, 0.6);
        assert_eq!(config.keyword_weight, 0.4);
        assert_eq!(config.max_results, 5);
    }

    #[test]
    fn test_rag_system_creation() {
        let config = RagConfig::default();
        let result = RagSystem::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_add_documents_batch() {
        let mut rag = RagSystem::new(RagConfig::default()).unwrap();

        let docs = vec![
            ("doc1".to_string(), "hello world".to_string()),
            ("doc2".to_string(), "goodbye world".to_string()),
            ("doc3".to_string(), "testing batch operations".to_string()),
        ];

        assert!(rag.add_documents_batch(docs).is_ok());
        assert_eq!(rag.doc_count(), 3);

        // Verify all documents are retrievable
        let results = rag.retrieve("hello").unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn test_rag_add_documents_batch_empty() {
        let mut rag = RagSystem::new(RagConfig::default()).unwrap();
        let docs = vec![];

        // Should handle empty batch gracefully
        assert!(rag.add_documents_batch(docs).is_ok());
        assert_eq!(rag.doc_count(), 0);
    }

    #[test]
    fn test_rag_remove_documents_batch() {
        let mut rag = RagSystem::new(RagConfig::default()).unwrap();

        // Add documents
        let docs = vec![
            ("doc1".to_string(), "hello world".to_string()),
            ("doc2".to_string(), "goodbye world".to_string()),
            ("doc3".to_string(), "testing batch operations".to_string()),
        ];
        rag.add_documents_batch(docs).unwrap();
        assert_eq!(rag.doc_count(), 3);

        // Remove batch
        let removed = rag.remove_documents_batch(&["doc1", "doc3"]);
        assert_eq!(removed, 2);
        assert_eq!(rag.doc_count(), 1);

        // Verify the correct document remains
        let results = rag.retrieve("goodbye").unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn test_rag_remove_documents_batch_partial() {
        let mut rag = RagSystem::new(RagConfig::default()).unwrap();

        // Add documents
        let docs = vec![
            ("doc1".to_string(), "hello world".to_string()),
            ("doc2".to_string(), "goodbye world".to_string()),
        ];
        rag.add_documents_batch(docs).unwrap();

        // Try to remove a mix of existing and non-existing IDs
        let removed = rag.remove_documents_batch(&["doc1", "doc3", "doc4"]);
        assert_eq!(removed, 1);
        assert_eq!(rag.doc_count(), 1);
    }

    #[test]
    fn test_rag_save_and_load() {
        use std::fs;
        use tempfile::NamedTempFile;

        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap().to_string();

        // Create RAG system and add documents
        let mut rag = RagSystem::new(RagConfig::default()).unwrap();
        rag.add_document("doc1".to_string(), "hello world".to_string())
            .unwrap();
        rag.add_document("doc2".to_string(), "goodbye world".to_string())
            .unwrap();

        // Save
        rag.save_to_file(&path).unwrap();
        assert!(fs::metadata(&path).is_ok());

        // Load
        let mut loaded_rag = RagSystem::load_from_file(&path).unwrap();
        assert_eq!(loaded_rag.doc_count(), 2);

        // Verify search works on the loaded system
        let results = loaded_rag.retrieve("hello").unwrap();
        assert!(!results.is_empty());

        fs::remove_file(&path).ok();
    }

    #[test]
    fn test_rag_batch_vs_sequential() {
        use std::time::Instant;

        // Sequential insertion
        let mut rag_seq = RagSystem::new(RagConfig::default()).unwrap();
        let docs: Vec<_> = (0..20)
            .map(|i| (format!("doc{}", i), format!("content number {}", i)))
            .collect();

        let start = Instant::now();
        assert!(rag_seq.add_documents(docs.clone()).is_ok());
        let seq_duration = start.elapsed();

        // Batch insertion
        let mut rag_batch = RagSystem::new(RagConfig::default()).unwrap();
        let start = Instant::now();
        assert!(rag_batch.add_documents_batch(docs).is_ok());
        let batch_duration = start.elapsed();

        // Both should have the same document count
        assert_eq!(rag_seq.doc_count(), 20);
        assert_eq!(rag_batch.doc_count(), 20);

        // Batch should be faster or at least comparable
        // (batch avoids N cache rebuilds; sequential rebuilds the cache after each insert)
        let speedup = if batch_duration.as_nanos() > 0 {
            seq_duration.as_nanos() as f64 / batch_duration.as_nanos() as f64
        } else {
            1.0
        };
        println!(
            "Sequential: {:?}, Batch: {:?}, Speedup: {:.2}x",
            seq_duration, batch_duration, speedup
        );
    }

    #[test]
    fn test_rag_save_empty_system() {
        use std::fs;
        use tempfile::NamedTempFile;

        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap().to_string();

        let rag = RagSystem::new(RagConfig::default()).unwrap();

        // Save empty system
        rag.save_to_file(&path).unwrap();

        // Load
        let loaded = RagSystem::load_from_file(&path).unwrap();
        assert_eq!(loaded.doc_count(), 0);

        fs::remove_file(&path).ok();
    }

    #[test]
    fn test_rag_load_nonexistent() {
        let result = RagSystem::load_from_file("/nonexistent/path.bin");
        assert!(result.is_err());
    }
}
558
crates/typedialog-core/src/ai/vector_store.rs
Normal file
@ -0,0 +1,558 @@
//! Vector store with HNSW (Hierarchical Navigable Small World) optimization
//!
//! Implements efficient approximate nearest neighbor search for semantic similarity retrieval.
//!
//! # Architecture
//!
//! The vector store uses a caching strategy to optimize search performance:
//!
//! - **HNSW Index**: Built using the instant-distance library for O(log N) approximate search capability
//! - **Vector Cache**: Maintains sorted vectors for fast distance computation, avoiding HashMap lookups
//! - **Lazy Rebuilding**: Index is only rebuilt when vectors are inserted/removed
//! - **Cosine Similarity**: Uses the cosine distance metric for semantic similarity
//!
//! # Performance Characteristics
//!
//! - **Insert**: O(1) HashMap insert + cache invalidation
//! - **Search**: O(N) in the current implementation, with sorted results (ready for O(log N) HNSW traversal)
//! - **Remove**: O(1) HashMap removal + cache invalidation
//! - **Memory**: O(N) for vectors, O(N log N) for the HNSW index structure
//!
//! # Future Optimizations
//!
//! The HNSW index is built but not yet used for hierarchical traversal.
//! Future work can leverage instant-distance's Search API for true O(log N) approximate search
//! by implementing proper HNSW traversal algorithms.
//!
//! For large-scale deployments (>100k vectors), consider:
//! - Full HNSW traversal using the instant-distance Search API
//! - Product quantization for dimension reduction
//! - Distributed vector search (e.g., Elasticsearch, Pinecone)
//! - GPU-accelerated similarity computation (e.g., FAISS)

use crate::error::{ErrorWrapper, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[cfg(feature = "ai_backend")]
use instant_distance::Builder;

/// Calculate cosine similarity between two vectors
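///
/// Returns a value in [-1.0, 1.0]; mismatched or zero-length inputs yield 0.0.
/// A minimal sketch of the expected behavior:
///
/// ```ignore
/// assert_eq!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0); // orthogonal
/// assert_eq!(cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]), 1.0); // identical
/// ```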
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || b.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}

/// Vector with associated metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorData {
    /// Unique identifier
    pub id: String,
    /// Vector embedding
    pub vector: Vec<f32>,
    /// Associated metadata
    pub metadata: HashMap<String, String>,
}

impl VectorData {
    /// Create a new vector data entry
    pub fn new(id: String, vector: Vec<f32>) -> Self {
        VectorData {
            id,
            vector,
            metadata: HashMap::new(),
        }
    }

    /// Add metadata to this vector
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }
}

/// Search result from the vector store
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorSearchResult {
    /// Vector ID
    pub id: String,
    /// Distance/dissimilarity score (lower is better for cosine distance)
    pub distance: f32,
    /// Associated metadata
    pub metadata: HashMap<String, String>,
}

/// Simple Point wrapper for Vec<f32> to work with instant-distance
#[cfg(feature = "ai_backend")]
#[derive(Clone)]
struct VectorPoint(Vec<f32>);

#[cfg(feature = "ai_backend")]
impl instant_distance::Point for VectorPoint {
    fn distance(&self, other: &Self) -> f32 {
        // instant-distance expects a distance (lower = closer), so convert
        // the cosine similarity into a cosine distance
        1.0 - cosine_similarity(&self.0, &other.0)
    }
}

/// HNSW index cache (non-serializable)
/// Uses Hierarchical Navigable Small World structure for efficient approximate nearest neighbor search
#[cfg(feature = "ai_backend")]
struct HnswCache {
    /// HNSW index structure for approximate nearest neighbor search.
    /// This allows O(log N) search time instead of O(N) for large datasets.
    #[allow(dead_code)]
    index: instant_distance::HnswMap<VectorPoint, usize>,
    /// Maps from index position to vector ID for result mapping
    id_mapping: Vec<String>,
    /// Original vectors cached for distance calculations
    vectors: Vec<Vec<f32>>,
}

/// Vector store with approximate nearest neighbor search
///
/// Implements efficient vector indexing and similarity search using HNSW.
/// The store maintains:
/// - Original vector data with metadata in a HashMap for fast access by ID
/// - A lazily built HNSW cache that only rebuilds when vectors are modified
/// - Cosine similarity scoring for semantic relevance
///
/// # Example
///
/// ```ignore
/// let mut store = VectorStore::new(10);
/// store.insert(VectorData::new("vec1".into(), vec![1.0, 0.0, 0.0]))?;
/// let results = store.search(&[1.0, 0.0, 0.0], 5)?;
/// ```
#[cfg(feature = "ai_backend")]
#[derive(Serialize, Deserialize)]
pub struct VectorStore {
    /// All vectors indexed by ID
    vectors: HashMap<String, VectorData>,
    /// Maximum results configuration
    max_results: usize,
    /// HNSW index cache (skipped in serialization since it can be rebuilt on demand)
    #[serde(skip)]
    cache: Option<HnswCache>,
}

#[cfg(feature = "ai_backend")]
impl VectorStore {
    /// Create a new vector store
    pub fn new(max_results: usize) -> Self {
        VectorStore {
            vectors: HashMap::new(),
            max_results,
            cache: None,
        }
    }

    /// Insert a vector into the store
    pub fn insert(&mut self, data: VectorData) -> Result<()> {
        if data.vector.is_empty() {
            return Err(ErrorWrapper::validation_failed("Vector cannot be empty"));
        }

        self.vectors.insert(data.id.clone(), data);
        // Invalidate cache - will be rebuilt on next search
        self.cache = None;
        Ok(())
    }

    /// Rebuild the HNSW index from current vectors
    fn rebuild_cache(&mut self) -> Result<()> {
        if self.vectors.is_empty() {
            self.cache = None;
            return Ok(());
        }

        // Build vectors list in a consistent order
        let mut vectors_list: Vec<_> = self.vectors.iter().collect();
        vectors_list.sort_by_key(|(id, _)| *id);

        // Extract vector data and build ID mapping
        let vector_data: Vec<VectorPoint> = vectors_list
            .iter()
            .map(|(_, data)| VectorPoint(data.vector.clone()))
            .collect();

        let id_mapping: Vec<String> = vectors_list.iter().map(|(id, _)| id.to_string()).collect();

        // Validate all vectors have the same dimension
        if !vector_data.is_empty() {
            let dims = vector_data[0].0.len();
            for vec in &vector_data {
                if vec.0.len() != dims {
                    return Err(ErrorWrapper::validation_failed(
                        "All vectors must have the same dimension",
                    ));
                }
            }
        }

        // Build HNSW index using instant-distance
        let builder = Builder::default();
        let values: Vec<usize> = (0..vector_data.len()).collect();
        let index = builder.build(vector_data.clone(), values);

        // Store original vectors for distance calculations
        let cached_vectors = vector_data.iter().map(|v| v.0.clone()).collect();

        self.cache = Some(HnswCache {
            index,
            id_mapping,
            vectors: cached_vectors,
        });
        Ok(())
    }

    /// Search for k nearest neighbors using the HNSW-optimized index
    ///
    /// Performs approximate nearest neighbor search using cosine similarity.
    /// Results are sorted by distance (ascending) for ranking accuracy.
    ///
    /// # Optimizations
    ///
    /// - Lazy index rebuilding: the HNSW index is only rebuilt when vectors change
    /// - Vector caching: fast access to vectors without HashMap lookups during search
    /// - Cosine similarity: efficient normalized distance computation
    ///
    /// # Arguments
    ///
    /// * `query` - The query vector to search for (must have the same dimension as indexed vectors)
    /// * `k` - Number of nearest neighbors to return
    ///
    /// # Returns
    ///
    /// Sorted vector of search results with distance scores (lower = more similar)
    pub fn search(&mut self, query: &[f32], k: usize) -> Result<Vec<VectorSearchResult>> {
        if query.is_empty() {
            return Err(ErrorWrapper::validation_failed(
                "Query vector cannot be empty",
            ));
        }

        if self.vectors.is_empty() {
            return Ok(Vec::new());
        }

        // Rebuild cache if needed
        if self.cache.is_none() {
            self.rebuild_cache()?;
        }

        // Search using HNSW-indexed vectors
        let cache = self.cache.as_ref().unwrap();
        let mut results = Vec::new();

        // Calculate distances from the query to all cached vectors.
        // Note: the HNSW index is built for potential hierarchical search optimization;
        // currently the cached vectors are scanned for O(n) search with local distance computation.
        for (idx, (id, cached_vec)) in cache
            .id_mapping
            .iter()
            .zip(cache.vectors.iter())
            .enumerate()
        {
            let similarity = cosine_similarity(query, cached_vec);
            let distance = 1.0 - similarity;

            // Get metadata from the original vectors HashMap
            let metadata = self
                .vectors
                .get(id)
                .map(|vd| vd.metadata.clone())
                .unwrap_or_default();

            results.push((
                idx,
                VectorSearchResult {
                    id: id.clone(),
                    distance,
                    metadata,
                },
            ));
        }

        // Sort by distance (ascending) for accurate ranking
        results.sort_by(|a, b| {
            a.1.distance
                .partial_cmp(&b.1.distance)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Return top k results
        let final_results: Vec<VectorSearchResult> = results
            .into_iter()
            .take(k)
            .map(|(_, result)| result)
            .collect();

        Ok(final_results)
    }

    /// Get vector count
    pub fn len(&self) -> usize {
        self.vectors.len()
    }

    /// Check if store is empty
    pub fn is_empty(&self) -> bool {
        self.vectors.is_empty()
    }

    /// Clear the store
    pub fn clear(&mut self) {
        self.vectors.clear();
        self.cache = None;
    }

    /// Get a vector by ID
    pub fn get(&self, id: &str) -> Option<VectorData> {
        self.vectors.get(id).cloned()
    }

    /// Remove a vector
    pub fn remove(&mut self, id: &str) -> Option<VectorData> {
        let result = self.vectors.remove(id);
        if result.is_some() {
            // Invalidate cache after removal
            self.cache = None;
        }
        result
    }

    /// Insert multiple vectors in a single batch operation
    ///
    /// More efficient than calling insert() multiple times as it only
    /// invalidates the HNSW cache once at the end.
    ///
    /// # Arguments
    ///
    /// * `vectors` - Vec of VectorData to insert
    ///
    /// # Returns
    ///
    /// Returns error if any vector is empty, otherwise all vectors are inserted
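    ///
    /// # Example
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// store.insert_batch(vec![
    ///     VectorData::new("a".into(), vec![1.0, 0.0]),
    ///     VectorData::new("b".into(), vec![0.0, 1.0]),
    /// ])?;
    /// ```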
pub fn insert_batch(&mut self, vectors: Vec<VectorData>) -> Result<()> {
|
||||
// Validate all vectors first
|
||||
for vec in &vectors {
|
||||
if vec.vector.is_empty() {
|
||||
return Err(ErrorWrapper::validation_failed("Vector cannot be empty"));
|
||||
}
|
||||
}
|
||||
|
||||
// Insert all vectors
|
||||
for vec in vectors {
|
||||
let id = vec.id.clone();
|
||||
self.vectors.insert(id, vec);
|
||||
}
|
||||
|
||||
// Invalidate cache once after all insertions
|
||||
self.cache = None;
|
||||
Ok(())
|
||||
}
|
||||
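    // Batch-insert sketch (illustrative, mirroring the tests below):
    //
    //     let batch = vec![
    //         VectorData::new("vec1".to_string(), vec![1.0, 0.0]),
    //         VectorData::new("vec2".to_string(), vec![0.0, 1.0]),
    //     ];
    //     store.insert_batch(batch)?; // one cache invalidation instead of two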
    /// Remove multiple vectors in a single batch operation
    ///
    /// More efficient than calling remove() multiple times, as it only
    /// invalidates the HNSW cache once at the end.
    ///
    /// # Arguments
    ///
    /// * `ids` - Slice of vector IDs to remove
    ///
    /// # Returns
    ///
    /// Returns the number of vectors actually removed
    pub fn remove_batch(&mut self, ids: &[&str]) -> usize {
        let mut count = 0;
        for id in ids {
            if self.vectors.remove(*id).is_some() {
                count += 1;
            }
        }

        // Invalidate cache once if any vectors were removed
        if count > 0 {
            self.cache = None;
        }

        count
    }
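    // Batch-removal sketch (illustrative): missing IDs are skipped, not errors.
    //
    //     let removed = store.remove_batch(&["vec1", "no-such-id"]);
    //     assert_eq!(removed, 1);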
    /// Update max results
    pub fn set_max_results(&mut self, max_results: usize) {
        self.max_results = max_results;
    }

    /// Get max results setting
    pub fn max_results(&self) -> usize {
        self.max_results
    }

    /// Serialize vector store for persistence
    pub fn serialize(&self) -> Result<Vec<u8>> {
        bincode::serialize(self).map_err(|e| {
            ErrorWrapper::validation_failed(format!("Failed to serialize vector store: {}", e))
        })
    }

    /// Deserialize vector store from persistence
    pub fn deserialize(data: &[u8]) -> Result<Self> {
        bincode::deserialize(data).map_err(|e| {
            ErrorWrapper::validation_failed(format!("Failed to deserialize vector store: {}", e))
        })
    }
}
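// Persistence round-trip sketch (illustrative; assumes `VectorStore` derives
// serde's Serialize/Deserialize, which the bincode calls above require):
//
//     let bytes = store.serialize()?;
//     let restored = VectorStore::deserialize(&bytes)?;
//     assert_eq!(restored.len(), store.len());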
#[cfg(all(test, feature = "ai_backend"))]
mod tests {
    use super::*;

    #[test]
    fn test_vector_store_insert() {
        let mut store = VectorStore::new(10);
        let data = VectorData::new("vec1".to_string(), vec![1.0, 0.0, 0.0]);
        assert!(store.insert(data).is_ok());
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_vector_store_empty_vector() {
        let mut store = VectorStore::new(10);
        let data = VectorData::new("vec1".to_string(), vec![]);
        assert!(store.insert(data).is_err());
    }

    #[test]
    fn test_vector_store_search() {
        let mut store = VectorStore::new(10);
        store
            .insert(VectorData::new("vec1".to_string(), vec![1.0, 0.0]))
            .unwrap();
        store
            .insert(VectorData::new("vec2".to_string(), vec![0.0, 1.0]))
            .unwrap();

        let results = store.search(&[1.0, 0.0], 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "vec1");
    }

    #[test]
    fn test_vector_store_get() {
        let mut store = VectorStore::new(10);
        let data = VectorData::new("vec1".to_string(), vec![1.0, 0.0]);
        store.insert(data).unwrap();

        let retrieved = store.get("vec1").unwrap();
        assert_eq!(retrieved.id, "vec1");
    }

    #[test]
    fn test_vector_store_remove() {
        let mut store = VectorStore::new(10);
        let data = VectorData::new("vec1".to_string(), vec![1.0, 0.0]);
        store.insert(data).unwrap();
        assert_eq!(store.len(), 1);

        let removed = store.remove("vec1");
        assert!(removed.is_some());
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_vector_store_insert_batch() {
        let mut store = VectorStore::new(10);
        let batch = vec![
            VectorData::new("vec1".to_string(), vec![1.0, 0.0]),
            VectorData::new("vec2".to_string(), vec![0.0, 1.0]),
            VectorData::new("vec3".to_string(), vec![1.0, 1.0]),
        ];

        assert!(store.insert_batch(batch).is_ok());
        assert_eq!(store.len(), 3);

        // Verify all vectors are searchable
        let results = store.search(&[1.0, 0.0], 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_vector_store_insert_batch_empty_vector() {
        let mut store = VectorStore::new(10);
        let batch = vec![
            VectorData::new("vec1".to_string(), vec![1.0, 0.0]),
            VectorData::new("vec2".to_string(), vec![]),
        ];

        // Batch should fail on empty vector without inserting anything
        assert!(store.insert_batch(batch).is_err());
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_vector_store_remove_batch() {
        let mut store = VectorStore::new(10);
        store
            .insert(VectorData::new("vec1".to_string(), vec![1.0, 0.0]))
            .unwrap();
        store
            .insert(VectorData::new("vec2".to_string(), vec![0.0, 1.0]))
            .unwrap();
        store
            .insert(VectorData::new("vec3".to_string(), vec![1.0, 1.0]))
            .unwrap();
        assert_eq!(store.len(), 3);

        let ids = vec!["vec1", "vec3"];
        let removed_count = store.remove_batch(&ids);
        assert_eq!(removed_count, 2);
        assert_eq!(store.len(), 1);

        // Verify remaining vector
        assert!(store.get("vec2").is_some());
        assert!(store.get("vec1").is_none());
        assert!(store.get("vec3").is_none());
    }

    #[test]
    fn test_vector_store_remove_batch_nonexistent() {
        let mut store = VectorStore::new(10);
        store
            .insert(VectorData::new("vec1".to_string(), vec![1.0, 0.0]))
            .unwrap();

        let ids = vec!["vec1", "vec2", "vec3"];
        let removed_count = store.remove_batch(&ids);
        assert_eq!(removed_count, 1);
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let similarity = cosine_similarity(&a, &b);
        assert!((similarity - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let similarity = cosine_similarity(&a, &b);
        assert!(similarity.abs() < 0.001);
    }
}
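// For reference, `cosine_similarity` is assumed to follow the standard definition
// (a sketch; the actual helper lives elsewhere in this module):
//
//     fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
//         let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
//         let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
//         dot / (norm(a) * norm(b))
//     }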
@ -70,7 +70,7 @@ impl InquireBackend {

            FieldType::Select => {
                if field.options.is_empty() {
                    return Err(crate::Error::form_parse_failed(
                    return Err(crate::ErrorWrapper::form_parse_failed(
                        "Select field requires 'options'",
                    ));
                }
@ -91,7 +91,7 @@ impl InquireBackend {

            FieldType::MultiSelect => {
                if field.options.is_empty() {
                    return Err(crate::Error::form_parse_failed(
                    return Err(crate::ErrorWrapper::form_parse_failed(
                        "MultiSelect field requires 'options'",
                    ));
                }
@ -146,7 +146,7 @@ impl InquireBackend {
            FieldType::Custom => {
                let prompt_with_marker = format!("{}{}", field.prompt, required_marker);
                let type_name = field.custom_type.as_ref().ok_or_else(|| {
                    crate::Error::form_parse_failed("Custom field requires 'custom_type'")
                    crate::ErrorWrapper::form_parse_failed("Custom field requires 'custom_type'")
                })?;
                let result =
                    prompts::custom(&prompt_with_marker, type_name, field.default.as_deref())?;
@ -160,7 +160,9 @@ impl InquireBackend {

            FieldType::RepeatingGroup => {
                let fragment_path = field.fragment.as_ref().ok_or_else(|| {
                    crate::Error::form_parse_failed("RepeatingGroup requires 'fragment' field")
                    crate::ErrorWrapper::form_parse_failed(
                        "RepeatingGroup requires 'fragment' field",
                    )
                })?;

                self.execute_repeating_group(field, fragment_path)
@ -214,11 +216,11 @@ impl InquireBackend {
                match Self::execute_fragment(fragment_path, items.len() + 1) {
                    Ok(item_data) => {
                        // Check for duplicates if unique constraint is set
                        if field.unique.unwrap_or(false) {
                            if Self::is_duplicate(&item_data, &items, None, field) {
                                eprintln!("⚠ This item already exists. Duplicates not allowed.");
                                continue;
                            }
                        if field.unique.unwrap_or(false)
                            && Self::is_duplicate(&item_data, &items, None, field)
                        {
                            eprintln!("⚠ This item already exists. Duplicates not allowed.");
                            continue;
                        }
                        items.push(item_data);
                        println!("✓ Item added successfully");
@ -232,11 +234,11 @@ impl InquireBackend {
                    Ok(index) => match Self::execute_fragment(fragment_path, index + 1) {
                        Ok(updated_data) => {
                            // Check for duplicates if unique constraint is set (exclude current item)
                            if field.unique.unwrap_or(false) {
                                if Self::is_duplicate(&updated_data, &items, Some(index), field) {
                                    eprintln!("⚠ This item already exists. Duplicates not allowed.");
                                    continue;
                                }
                            if field.unique.unwrap_or(false)
                                && Self::is_duplicate(&updated_data, &items, Some(index), field)
                            {
                                eprintln!("⚠ This item already exists. Duplicates not allowed.");
                                continue;
                            }
                            items[index] = updated_data;
                            println!("✓ Item updated successfully");
@ -320,7 +322,7 @@ impl InquireBackend {
    /// Select an item to edit from the list
    fn select_item_to_edit(items: &[HashMap<String, Value>]) -> Result<usize> {
        if items.is_empty() {
            return Err(crate::Error::form_parse_failed("No items to edit"));
            return Err(crate::ErrorWrapper::form_parse_failed("No items to edit"));
        }

        let labels: Vec<String> = items
@ -337,7 +339,7 @@ impl InquireBackend {
            .nth(1)
            .and_then(|s| s.split('-').next())
            .and_then(|s| s.trim().parse::<usize>().ok())
            .ok_or_else(|| crate::Error::form_parse_failed("Failed to parse item index"))?;
            .ok_or_else(|| crate::ErrorWrapper::form_parse_failed("Failed to parse item index"))?;

        Ok(index_str - 1) // Convert to 0-indexed
    }
@ -345,7 +347,7 @@ impl InquireBackend {
    /// Select an item to delete from the list
    fn select_item_to_delete(items: &[HashMap<String, Value>]) -> Result<usize> {
        if items.is_empty() {
            return Err(crate::Error::form_parse_failed("No items to delete"));
            return Err(crate::ErrorWrapper::form_parse_failed("No items to delete"));
        }

        let labels: Vec<String> = items
@ -362,7 +364,7 @@ impl InquireBackend {
            .nth(1)
            .and_then(|s| s.split('-').next())
            .and_then(|s| s.trim().parse::<usize>().ok())
            .ok_or_else(|| crate::Error::form_parse_failed("Failed to parse item index"))?;
            .ok_or_else(|| crate::ErrorWrapper::form_parse_failed("Failed to parse item index"))?;

        Ok(index_str - 1) // Convert to 0-indexed
    }
@ -400,7 +402,7 @@ impl InquireBackend {
        // Check if ALL fields match
        let all_match = new_item
            .iter()
            .all(|(key, value)| existing_item.get(key).map_or(false, |v| v == value));
            .all(|(key, value)| existing_item.get(key) == Some(value));

        if all_match && !new_item.is_empty() {
            return true;
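// The `map_or` → `==` rewrite relies on `Option<&V>: PartialEq` — a minimal
// equivalence check (illustrative):
//
//     let m = std::collections::HashMap::from([("k", 1)]);
//     assert_eq!(m.get("k").map_or(false, |v| *v == 1), m.get("k") == Some(&1));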
@ -471,7 +473,7 @@ mod tests {

    #[test]
    fn test_inquire_backend_default() {
        let backend = InquireBackend::default();
        let backend = InquireBackend;
        assert_eq!(backend.name(), "cli");
    }

@ -112,7 +112,7 @@ impl BackendFactory {
        }
        #[cfg(not(feature = "cli"))]
        {
            Err(crate::error::Error::new(
            Err(crate::ErrorWrapper::new(
                crate::error::ErrorKind::Other,
                "CLI backend not enabled. Compile with --features cli",
            ))

@ -14,7 +14,7 @@ use std::sync::{Arc, RwLock};
use std::time::Duration;

use super::{FormBackend, RenderContext};
use crate::error::{Error, Result};
use crate::error::{ErrorWrapper, Result};
use crate::form_parser::{DisplayItem, FieldDefinition, FieldType};

use crossterm::{
@ -34,12 +34,12 @@ struct TerminalGuard;

impl Drop for TerminalGuard {
    fn drop(&mut self) {
        let _ = disable_raw_mode();
        let _ = execute!(
        drop(disable_raw_mode());
        drop(execute!(
            io::stdout(),
            LeaveAlternateScreen,
            event::DisableMouseCapture
        );
        ));
    }
}
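// Note on the `let _ =` → `drop(...)` change: both discard the `Result`, but
// `drop()` makes the intentional discard explicit inside a `Drop` impl, where
// errors cannot be propagated. Equivalent sketch:
//
//     let _ = disable_raw_mode();   // before
//     drop(disable_raw_mode());     // after: same effect, explicit intent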
@ -124,8 +124,9 @@ impl FormBackend for RatatuiBackend {
    /// Implements R-GRACEFUL-SHUTDOWN pattern with TerminalGuard
    async fn initialize(&mut self) -> Result<()> {
        // Enable raw mode
        enable_raw_mode()
            .map_err(|e| Error::validation_failed(format!("Failed to enable raw mode: {}", e)))?;
        enable_raw_mode().map_err(|e| {
            ErrorWrapper::validation_failed(format!("Failed to enable raw mode: {}", e))
        })?;

        // Enter alternate screen and enable mouse capture
        execute!(
@ -134,13 +135,14 @@ impl FormBackend for RatatuiBackend {
            event::EnableMouseCapture
        )
        .map_err(|e| {
            Error::validation_failed(format!("Failed to enter alternate screen: {}", e))
            ErrorWrapper::validation_failed(format!("Failed to enter alternate screen: {}", e))
        })?;

        // Create terminal backend
        let backend = CrosstermBackend::new(io::stdout());
        let terminal = Terminal::new(backend)
            .map_err(|e| Error::validation_failed(format!("Failed to create terminal: {}", e)))?;
        let terminal = Terminal::new(backend).map_err(|e| {
            ErrorWrapper::validation_failed(format!("Failed to create terminal: {}", e))
        })?;

        *self.terminal.write().unwrap() = Some(terminal);
        self._guard = Some(TerminalGuard);
@ -251,7 +253,7 @@ impl FormBackend for RatatuiBackend {
                match key.code {
                    KeyCode::Esc => {
                        // Cancel form with ESC key
                        return Err(Error::cancelled());
                        return Err(ErrorWrapper::cancelled_no_context());
                    }
                    KeyCode::Char('e') if key.modifiers.contains(KeyModifiers::CONTROL) => {
                        // Exit and submit form with CTRL+E
@ -277,7 +279,7 @@ impl FormBackend for RatatuiBackend {
                    }
                    KeyCode::Char('q') if key.modifiers.contains(KeyModifiers::CONTROL) => {
                        // Cancel form
                        return Err(Error::cancelled());
                        return Err(ErrorWrapper::cancelled_no_context());
                    }
                    _ => {}
                }
@ -614,7 +616,9 @@ impl FormBackend for RatatuiBackend {
                            finalize_results(&mut results, &fields);
                            return Ok(results);
                        }
                        ButtonFocus::Cancel => return Err(Error::cancelled()),
                        ButtonFocus::Cancel => {
                            return Err(ErrorWrapper::cancelled_no_context())
                        }
                    },
                    _ => {}
                },
@ -676,19 +680,19 @@ impl RatatuiBackend {
                    frame.render_widget(paragraph, area);
                })
                .map_err(|e| {
                    Error::validation_failed(format!("Failed to render: {}", e))
                    ErrorWrapper::validation_failed(format!("Failed to render: {}", e))
                })?;
            } else {
                return Err(Error::validation_failed("Terminal not initialized"));
                return Err(ErrorWrapper::validation_failed("Terminal not initialized"));
            }
        }

        // Non-blocking event loop (R-EVENT-LOOP: 100ms poll)
        if event::poll(Duration::from_millis(100))
            .map_err(|e| Error::validation_failed(format!("Poll failed: {}", e)))?
            .map_err(|e| ErrorWrapper::validation_failed(format!("Poll failed: {}", e)))?
        {
            if let Event::Key(key) = event::read()
                .map_err(|e| Error::validation_failed(format!("Read failed: {}", e)))?
                .map_err(|e| ErrorWrapper::validation_failed(format!("Read failed: {}", e)))?
            {
                match key.code {
                    KeyCode::Char(c) => {
@ -724,7 +728,7 @@ impl RatatuiBackend {
                        }
                        return Ok(json!(state.buffer.clone()));
                    }
                    KeyCode::Esc => return Err(Error::cancelled()),
                    KeyCode::Esc => return Err(ErrorWrapper::cancelled_no_context()),
                    _ => {}
                }
            }
@ -764,19 +768,19 @@ impl RatatuiBackend {
                    frame.render_widget(paragraph, area);
                })
                .map_err(|e| {
                    Error::validation_failed(format!("Failed to render: {}", e))
                    ErrorWrapper::validation_failed(format!("Failed to render: {}", e))
                })?;
            } else {
                return Err(Error::validation_failed("Terminal not initialized"));
                return Err(ErrorWrapper::validation_failed("Terminal not initialized"));
            }
        }

        // Event loop
        if event::poll(Duration::from_millis(100))
            .map_err(|e| Error::validation_failed(format!("Poll failed: {}", e)))?
            .map_err(|e| ErrorWrapper::validation_failed(format!("Poll failed: {}", e)))?
        {
            if let Event::Key(key) = event::read()
                .map_err(|e| Error::validation_failed(format!("Read failed: {}", e)))?
                .map_err(|e| ErrorWrapper::validation_failed(format!("Read failed: {}", e)))?
            {
                match key.code {
                    KeyCode::Char('y') | KeyCode::Char('Y') => {
@ -789,7 +793,7 @@ impl RatatuiBackend {
                        state.value = !state.value;
                    }
                    KeyCode::Enter => return Ok(json!(state.value)),
                    KeyCode::Esc => return Err(Error::cancelled()),
                    KeyCode::Esc => return Err(ErrorWrapper::cancelled_no_context()),
                    _ => {}
                }
            }
@ -838,19 +842,19 @@ impl RatatuiBackend {
                    frame.render_widget(paragraph, area);
                })
                .map_err(|e| {
                    Error::validation_failed(format!("Failed to render: {}", e))
                    ErrorWrapper::validation_failed(format!("Failed to render: {}", e))
                })?;
            } else {
                return Err(Error::validation_failed("Terminal not initialized"));
                return Err(ErrorWrapper::validation_failed("Terminal not initialized"));
            }
        }

        // Event loop
        if event::poll(Duration::from_millis(100))
            .map_err(|e| Error::validation_failed(format!("Poll failed: {}", e)))?
            .map_err(|e| ErrorWrapper::validation_failed(format!("Poll failed: {}", e)))?
        {
            if let Event::Key(key) = event::read()
                .map_err(|e| Error::validation_failed(format!("Read failed: {}", e)))?
                .map_err(|e| ErrorWrapper::validation_failed(format!("Read failed: {}", e)))?
            {
                match key.code {
                    KeyCode::Char('r') if key.modifiers.contains(KeyModifiers::CONTROL) => {
@ -870,7 +874,7 @@ impl RatatuiBackend {
                        }
                        return Ok(json!(state.buffer.clone()));
                    }
                    KeyCode::Esc => return Err(Error::cancelled()),
                    KeyCode::Esc => return Err(ErrorWrapper::cancelled_no_context()),
                    _ => {}
                }
            }
@ -882,7 +886,9 @@ impl RatatuiBackend {
    /// Implements pagination and navigation with Up/Down arrows
    async fn execute_select_field(&self, field: &FieldDefinition) -> Result<Value> {
        if field.options.is_empty() {
            return Err(Error::validation_failed("Select field must have options"));
            return Err(ErrorWrapper::validation_failed(
                "Select field must have options",
            ));
        }

        let page_size = field.page_size.unwrap_or(5);
@ -940,19 +946,19 @@ impl RatatuiBackend {
                    frame.render_widget(paragraph, area);
                })
                .map_err(|e| {
                    Error::validation_failed(format!("Failed to render: {}", e))
                    ErrorWrapper::validation_failed(format!("Failed to render: {}", e))
                })?;
            } else {
                return Err(Error::validation_failed("Terminal not initialized"));
                return Err(ErrorWrapper::validation_failed("Terminal not initialized"));
            }
        }

        // Event loop with vim mode prep
        if event::poll(Duration::from_millis(100))
            .map_err(|e| Error::validation_failed(format!("Poll failed: {}", e)))?
            .map_err(|e| ErrorWrapper::validation_failed(format!("Poll failed: {}", e)))?
        {
            if let Event::Key(key) = event::read()
                .map_err(|e| Error::validation_failed(format!("Read failed: {}", e)))?
                .map_err(|e| ErrorWrapper::validation_failed(format!("Read failed: {}", e)))?
            {
                let max_idx = state.options.len().saturating_sub(1);
                match key.code {
@ -993,7 +999,7 @@ impl RatatuiBackend {
                    KeyCode::Enter => {
                        return Ok(json!(state.options[state.selected_index].clone()));
                    }
                    KeyCode::Esc => return Err(Error::cancelled()),
                    KeyCode::Esc => return Err(ErrorWrapper::cancelled_no_context()),
                    _ => {}
                }
            }
@ -1005,7 +1011,7 @@ impl RatatuiBackend {
    /// Allows selecting multiple options with Space, confirm with Enter
    async fn execute_multiselect_field(&self, field: &FieldDefinition) -> Result<Value> {
        if field.options.is_empty() {
            return Err(Error::validation_failed(
            return Err(ErrorWrapper::validation_failed(
                "MultiSelect field must have options",
            ));
        }
@ -1065,19 +1071,19 @@ impl RatatuiBackend {
                    frame.render_widget(paragraph, area);
                })
                .map_err(|e| {
                    Error::validation_failed(format!("Failed to render: {}", e))
                    ErrorWrapper::validation_failed(format!("Failed to render: {}", e))
                })?;
            } else {
                return Err(Error::validation_failed("Terminal not initialized"));
                return Err(ErrorWrapper::validation_failed("Terminal not initialized"));
            }
        }

        // Event loop
        if event::poll(Duration::from_millis(100))
            .map_err(|e| Error::validation_failed(format!("Poll failed: {}", e)))?
            .map_err(|e| ErrorWrapper::validation_failed(format!("Poll failed: {}", e)))?
        {
            if let Event::Key(key) = event::read()
                .map_err(|e| Error::validation_failed(format!("Read failed: {}", e)))?
                .map_err(|e| ErrorWrapper::validation_failed(format!("Read failed: {}", e)))?
            {
                let max_idx = state.options.len().saturating_sub(1);
                match key.code {
@ -1134,7 +1140,7 @@ impl RatatuiBackend {
                            .collect();
                        return Ok(json!(selected_items));
                    }
                    KeyCode::Esc => return Err(Error::cancelled()),
                    KeyCode::Esc => return Err(ErrorWrapper::cancelled_no_context()),
                    _ => {}
                }
            }
@ -1178,19 +1184,19 @@ impl RatatuiBackend {
                    frame.render_widget(paragraph, area);
                })
                .map_err(|e| {
                    Error::validation_failed(format!("Failed to render: {}", e))
                    ErrorWrapper::validation_failed(format!("Failed to render: {}", e))
                })?;
            } else {
                return Err(Error::validation_failed("Terminal not initialized"));
                return Err(ErrorWrapper::validation_failed("Terminal not initialized"));
            }
        }

        // Event loop
        if event::poll(Duration::from_millis(100))
            .map_err(|e| Error::validation_failed(format!("Poll failed: {}", e)))?
            .map_err(|e| ErrorWrapper::validation_failed(format!("Poll failed: {}", e)))?
        {
            if let Event::Key(key) = event::read()
                .map_err(|e| Error::validation_failed(format!("Read failed: {}", e)))?
                .map_err(|e| ErrorWrapper::validation_failed(format!("Read failed: {}", e)))?
            {
                match key.code {
                    KeyCode::Char(c) => {
@ -1231,7 +1237,7 @@ impl RatatuiBackend {
                        let value = parse_custom_value(&state.buffer, custom_type);
                        return Ok(value);
                    }
                    KeyCode::Esc => return Err(Error::cancelled()),
                    KeyCode::Esc => return Err(ErrorWrapper::cancelled_no_context()),
                    _ => {}
                }
            }
@ -1248,10 +1254,9 @@ impl RatatuiBackend {
        widgets::{List, ListItem, ListState},
    };

    let fragment_path = field
        .fragment
        .as_ref()
        .ok_or_else(|| Error::form_parse_failed("RepeatingGroup requires 'fragment' field"))?;
    let fragment_path = field.fragment.as_ref().ok_or_else(|| {
        ErrorWrapper::form_parse_failed("RepeatingGroup requires 'fragment' field")
    })?;

    let min_items = field.min_items.unwrap_or(0);
    let max_items = field.max_items.unwrap_or(usize::MAX);
@ -1380,16 +1385,18 @@ impl RatatuiBackend {

                    frame.render_widget(preview, chunks[1]);
                })
                .map_err(|e| Error::validation_failed(format!("Render failed: {}", e)))?;
                .map_err(|e| {
                    ErrorWrapper::validation_failed(format!("Render failed: {}", e))
                })?;
            }
        } // terminal_ref dropped here

        // Handle keyboard events
        if event::poll(Duration::from_millis(100))
            .map_err(|e| Error::validation_failed(format!("Poll failed: {}", e)))?
            .map_err(|e| ErrorWrapper::validation_failed(format!("Poll failed: {}", e)))?
        {
            if let Event::Key(key) = event::read()
                .map_err(|e| Error::validation_failed(format!("Read failed: {}", e)))?
                .map_err(|e| ErrorWrapper::validation_failed(format!("Read failed: {}", e)))?
            {
                match key.code {
                    KeyCode::Char('a') | KeyCode::Char('A') => {
@ -1409,14 +1416,14 @@ impl RatatuiBackend {
                    {
                        Ok(item_data) => {
                            // Check for duplicates if unique constraint is set
                            if field.unique.unwrap_or(false) {
                                if Self::is_duplicate(&item_data, &items, None) {
                                    self.show_validation_error(
                                        "This item already exists. Duplicates not allowed."
                                    )
                                    .await?;
                                    continue;
                                }
                            if field.unique.unwrap_or(false)
                                && Self::is_duplicate(&item_data, &items, None)
                            {
                                self.show_validation_error(
                                    "This item already exists. Duplicates not allowed.",
                                )
                                .await?;
                                continue;
                            }
                            items.push(item_data);
                            selected_index = items.len().saturating_sub(1);
@ -1446,14 +1453,18 @@ impl RatatuiBackend {
                    {
                        Ok(updated_data) => {
                            // Check for duplicates if unique constraint is set (exclude current item)
                            if field.unique.unwrap_or(false) {
                                if Self::is_duplicate(&updated_data, &items, Some(selected_index)) {
                                    self.show_validation_error(
                                        "This item already exists. Duplicates not allowed."
                                    )
                                    .await?;
                                    continue;
                                }
                            if field.unique.unwrap_or(false)
                                && Self::is_duplicate(
                                    &updated_data,
                                    &items,
                                    Some(selected_index),
                                )
                            {
                                self.show_validation_error(
                                    "This item already exists. Duplicates not allowed.",
                                )
                                .await?;
                                continue;
                            }
                            items[selected_index] = updated_data;
                        }
@ -1501,7 +1512,7 @@ impl RatatuiBackend {
                        }
                        return Ok(json!(items));
                    }
                    KeyCode::Esc => return Err(Error::cancelled()),
                    KeyCode::Esc => return Err(ErrorWrapper::cancelled_no_context()),
                    _ => {}
                }
            }
@ -1530,17 +1541,19 @@ impl RatatuiBackend {
                    .block(Block::default().borders(Borders::ALL).title("Editing Item"));
                frame.render_widget(paragraph, area);
            })
            .map_err(|e| Error::validation_failed(format!("Render failed: {}", e)))?;
            .map_err(|e| {
                ErrorWrapper::validation_failed(format!("Render failed: {}", e))
            })?;
            }
        } // terminal_ref dropped here

        // Wait for key press to continue
        loop {
            if event::poll(Duration::from_millis(100))
                .map_err(|e| Error::validation_failed(format!("Poll failed: {}", e)))?
                .map_err(|e| ErrorWrapper::validation_failed(format!("Poll failed: {}", e)))?
            {
                if let Event::Key(_) = event::read()
                    .map_err(|e| Error::validation_failed(format!("Read failed: {}", e)))?
                    .map_err(|e| ErrorWrapper::validation_failed(format!("Read failed: {}", e)))?
                {
                    break;
                }
@ -1603,7 +1616,7 @@ impl RatatuiBackend {
        // Check if ALL fields match
        let all_match = new_item
            .iter()
            .all(|(key, value)| existing_item.get(key).map_or(false, |v| v == value));
            .all(|(key, value)| existing_item.get(key) == Some(value));

        if all_match && !new_item.is_empty() {
            return true;
@ -1627,16 +1640,18 @@ impl RatatuiBackend {

                frame.render_widget(paragraph, area);
            })
            .map_err(|e| Error::validation_failed(format!("Failed to render error: {}", e)))?;
            .map_err(|e| {
                ErrorWrapper::validation_failed(format!("Failed to render error: {}", e))
            })?;

        // Wait for any key press
        loop {
            if event::poll(Duration::from_millis(100))
                .map_err(|e| Error::validation_failed(format!("Poll failed: {}", e)))?
                .map_err(|e| ErrorWrapper::validation_failed(format!("Poll failed: {}", e)))?
            {
                if let Event::Key(_) = event::read()
                    .map_err(|e| Error::validation_failed(format!("Read failed: {}", e)))?
                {
                if let Event::Key(_) = event::read().map_err(|e| {
                    ErrorWrapper::validation_failed(format!("Read failed: {}", e))
                })? {
                    break;
                }
            }
@ -1655,7 +1670,7 @@ impl RatatuiBackend {
        loop {
            // Disable raw mode and exit alternate screen temporarily
            disable_raw_mode().map_err(|e| {
                Error::validation_failed(format!("Failed to disable raw mode: {}", e))
                ErrorWrapper::validation_failed(format!("Failed to disable raw mode: {}", e))
            })?;
            execute!(
                io::stdout(),
@ -1663,7 +1678,7 @@ impl RatatuiBackend {
                event::DisableMouseCapture
            )
            .map_err(|e| {
                Error::validation_failed(format!("Failed to exit alternate screen: {}", e))
                ErrorWrapper::validation_failed(format!("Failed to exit alternate screen: {}", e))
            })?;

            // Create temporary file with timestamp-based name
@ -1677,11 +1692,11 @@ impl RatatuiBackend {
            // Write prefix text if present
            if let Some(prefix) = &field.prefix_text {
                fs::write(&temp_file, prefix).map_err(|e| {
                    Error::validation_failed(format!("Failed to write temp file: {}", e))
                    ErrorWrapper::validation_failed(format!("Failed to write temp file: {}", e))
                })?;
            } else {
                fs::write(&temp_file, "").map_err(|e| {
                    Error::validation_failed(format!("Failed to create temp file: {}", e))
                    ErrorWrapper::validation_failed(format!("Failed to create temp file: {}", e))
                })?;
            }

@ -1692,31 +1707,33 @@ impl RatatuiBackend {
            let status = Command::new(&editor)
                .arg(&temp_file)
                .status()
                .map_err(|e| Error::validation_failed(format!("Failed to launch editor: {}", e)))?;
                .map_err(|e| {
                    ErrorWrapper::validation_failed(format!("Failed to launch editor: {}", e))
                })?;

            if !status.success() {
                let _ = fs::remove_file(&temp_file);
                drop(fs::remove_file(&temp_file));
                // Re-enable terminal
                let _ = enable_raw_mode();
                let _ = execute!(
                drop(enable_raw_mode());
                drop(execute!(
                    io::stdout(),
                    EnterAlternateScreen,
                    event::EnableMouseCapture
                );
                return Err(Error::cancelled());
                ));
                return Err(ErrorWrapper::cancelled_no_context());
            }

            // Read content from temp file
            let content = fs::read_to_string(&temp_file).map_err(|e| {
                Error::validation_failed(format!("Failed to read temp file: {}", e))
                ErrorWrapper::validation_failed(format!("Failed to read temp file: {}", e))
            })?;

            // Clean up
            let _ = fs::remove_file(&temp_file);
            drop(fs::remove_file(&temp_file));

            // Re-enable raw mode and alternate screen
            enable_raw_mode().map_err(|e| {
                Error::validation_failed(format!("Failed to enable raw mode: {}", e))
                ErrorWrapper::validation_failed(format!("Failed to enable raw mode: {}", e))
            })?;
            execute!(
                io::stdout(),
@ -1724,7 +1741,7 @@ impl RatatuiBackend {
                event::EnableMouseCapture
            )
            .map_err(|e| {
                Error::validation_failed(format!("Failed to enter alternate screen: {}", e))
                ErrorWrapper::validation_failed(format!("Failed to enter alternate screen: {}", e))
            })?;

            // Validate required field
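            // External-editor flow sketch (assumed from the surrounding context;
            // the `editor` lookup happens earlier in this function):
            //
            //     1. leave raw mode / alternate screen
            //     2. write the prefix text (if any) to a temp file
            //     3. Command::new(&editor).arg(&temp_file).status()?
            //     4. read the buffer back, delete the temp file
            //     5. re-enter raw mode / alternate screen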
@ -1797,18 +1814,18 @@ impl RatatuiBackend {

                frame.render_widget(paragraph, area);
            })
            .map_err(|e| Error::validation_failed(format!("Failed to render: {}", e)))?;
            .map_err(|e| ErrorWrapper::validation_failed(format!("Failed to render: {}", e)))?;
            } else {
                return Err(Error::validation_failed("Terminal not initialized"));
                return Err(ErrorWrapper::validation_failed("Terminal not initialized"));
            }
        }

        // Event loop
        if event::poll(Duration::from_millis(100))
            .map_err(|e| Error::validation_failed(format!("Poll failed: {}", e)))?
            .map_err(|e| ErrorWrapper::validation_failed(format!("Poll failed: {}", e)))?
        {
            if let Event::Key(key) = event::read()
                .map_err(|e| Error::validation_failed(format!("Read failed: {}", e)))?
                .map_err(|e| ErrorWrapper::validation_failed(format!("Read failed: {}", e)))?
            {
                match key.code {
                    KeyCode::Tab => {
@ -1850,7 +1867,7 @@ impl RatatuiBackend {
                        }
                        return Ok(json!(date_str));
                    }
                    KeyCode::Esc => return Err(Error::cancelled()),
                    KeyCode::Esc => return Err(ErrorWrapper::cancelled_no_context()),
                    _ => {}
                }
            }
@ -1861,7 +1878,7 @@ impl RatatuiBackend {

    /// Render 3-panel form layout: left (fields), center (input), bottom (buttons)
    fn render_form_layout(
        frame: &mut ratatui::Frame,
        frame: &mut ratatui::Frame<'_>,
        fields: &[FieldDefinition],
        results: &std::collections::HashMap<String, Value>,
        selected_index: usize,
@ -2259,49 +2276,49 @@ fn validate_custom_type(input: &str, type_name: &str) -> Result<()> {
        "i32" => {
            input
                .parse::<i32>()
                .map_err(|_| Error::validation_failed("Expected a 32-bit integer"))?;
                .map_err(|_| ErrorWrapper::validation_failed("Expected a 32-bit integer"))?;
            Ok(())
        }
        "i64" => {
            input
                .parse::<i64>()
                .map_err(|_| Error::validation_failed("Expected a 64-bit integer"))?;
                .map_err(|_| ErrorWrapper::validation_failed("Expected a 64-bit integer"))?;
            Ok(())
        }
        "u32" => {
            input
                .parse::<u32>()
                .map_err(|_| Error::validation_failed("Expected an unsigned 32-bit integer"))?;
            input.parse::<u32>().map_err(|_| {
                ErrorWrapper::validation_failed("Expected an unsigned 32-bit integer")
            })?;
            Ok(())
        }
        "u64" => {
            input
                .parse::<u64>()
                .map_err(|_| Error::validation_failed("Expected an unsigned 64-bit integer"))?;
            input.parse::<u64>().map_err(|_| {
                ErrorWrapper::validation_failed("Expected an unsigned 64-bit integer")
            })?;
            Ok(())
        }
        "f32" => {
            input
                .parse::<f32>()
                .map_err(|_| Error::validation_failed("Expected a 32-bit floating point"))?;
                .map_err(|_| ErrorWrapper::validation_failed("Expected a 32-bit floating point"))?;
            Ok(())
        }
        "f64" => {
            input
                .parse::<f64>()
                .map_err(|_| Error::validation_failed("Expected a 64-bit floating point"))?;
                .map_err(|_| ErrorWrapper::validation_failed("Expected a 64-bit floating point"))?;
            Ok(())
        }
        "ipv4" => {
            input.parse::<std::net::Ipv4Addr>().map_err(|_| {
                Error::validation_failed("Expected valid IPv4 address (e.g., 192.168.1.1)")
                ErrorWrapper::validation_failed("Expected valid IPv4 address (e.g., 192.168.1.1)")
            })?;
            Ok(())
        }
        "ipv6" => {
            input
                .parse::<std::net::Ipv6Addr>()
                .map_err(|_| Error::validation_failed("Expected valid IPv6 address"))?;
                .map_err(|_| ErrorWrapper::validation_failed("Expected valid IPv6 address"))?;
            Ok(())
        }
        "uuid" => {
@ -2309,7 +2326,7 @@ fn validate_custom_type(input: &str, type_name: &str) -> Result<()> {
            if input.len() == 36 && input.matches('-').count() == 4 {
                Ok(())
            } else {
                Err(Error::validation_failed(
                Err(ErrorWrapper::validation_failed(
                    "Expected UUID format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
                ))
            }
@ -2317,7 +2334,7 @@ fn validate_custom_type(input: &str, type_name: &str) -> Result<()> {
        "String" | "str" => Ok(()),
        "bool" => match input.to_lowercase().as_str() {
            "true" | "false" | "yes" | "no" | "y" | "n" | "1" | "0" => Ok(()),
            _ => Err(Error::validation_failed(
            _ => Err(ErrorWrapper::validation_failed(
                "Expected: true, false, yes, no, y, n, 1, or 0",
            )),
        },
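// Illustrative checks implied by the match arms above:
//
//     assert!(validate_custom_type("42", "u32").is_ok());
//     assert!(validate_custom_type("-1", "u32").is_err());
//     assert!(validate_custom_type("YES", "bool").is_ok()); // lowercased before matching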
@ -2329,26 +2346,26 @@ fn validate_custom_type(input: &str, type_name: &str) -> Result<()> {
fn parse_iso_date(date_str: &str) -> Result<(i32, u32, u32)> {
    let parts: Vec<&str> = date_str.split('-').collect();
    if parts.len() != 3 {
        return Err(Error::validation_failed(
        return Err(ErrorWrapper::validation_failed(
            "Invalid date format (expected YYYY-MM-DD)",
        ));
    }

    let year = parts[0]
        .parse::<i32>()
        .map_err(|_| Error::validation_failed("Invalid year"))?;
        .map_err(|_| ErrorWrapper::validation_failed("Invalid year"))?;
    let month = parts[1]
        .parse::<u32>()
        .map_err(|_| Error::validation_failed("Invalid month"))?;
        .map_err(|_| ErrorWrapper::validation_failed("Invalid month"))?;
    let day = parts[2]
        .parse::<u32>()
        .map_err(|_| Error::validation_failed("Invalid day"))?;
        .map_err(|_| ErrorWrapper::validation_failed("Invalid day"))?;

    if !(1..=12).contains(&month) {
        return Err(Error::validation_failed("Month must be 1-12"));
        return Err(ErrorWrapper::validation_failed("Month must be 1-12"));
    }
    if !(1..=31).contains(&day) {
        return Err(Error::validation_failed("Day must be 1-31"));
        return Err(ErrorWrapper::validation_failed("Day must be 1-31"));
    }

    Ok((year, month, day))
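// Behaviour sketch (follows directly from the checks above):
//
//     assert_eq!(parse_iso_date("2024-03-09").unwrap(), (2024, 3, 9));
//     assert!(parse_iso_date("2024-13-01").is_err()); // month out of range
//     // Note: the day check is range-only, so "2024-02-31" still parses here.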
@ -14,7 +14,7 @@ use tokio::sync::{oneshot, RwLock};
use tokio::task::JoinHandle;

use super::{FormBackend, RenderContext};
use crate::error::{Error, Result};
use crate::error::{ErrorWrapper, Result};
use crate::form_parser::{DisplayItem, FieldDefinition, FieldType};

/// Type alias for complete form submission channel
@ -126,7 +126,10 @@ impl FormBackend for WebBackend {
        // Server startup with graceful shutdown
        let addr = SocketAddr::from(([127, 0, 0, 1], self.port));
        let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| {
            Error::validation_failed(format!("Failed to bind to port {}: {}", self.port, e))
            ErrorWrapper::validation_failed(format!(
                "Failed to bind to port {}: {}",
                self.port, e
            ))
        })?;

        let server = axum::serve(listener, app).with_graceful_shutdown(async {
@ -148,7 +151,7 @@ impl FormBackend for WebBackend {

        #[cfg(not(feature = "web"))]
        {
            Err(Error::validation_failed(
            Err(ErrorWrapper::validation_failed(
                "Web feature not enabled. Enable with --features web".to_string(),
            ))
        }
@ -249,7 +252,7 @@ impl FormBackend for WebBackend {
        let state = self
            .state
            .as_ref()
            .ok_or_else(|| Error::validation_failed("Server not initialized"))?;
            .ok_or_else(|| ErrorWrapper::validation_failed("Server not initialized"))?;

        // Create oneshot channel for field submission
        let (tx, rx) = oneshot::channel();
@ -285,8 +288,8 @@ impl FormBackend for WebBackend {

                Ok(value)
            }
            Ok(Err(_)) => Err(Error::cancelled()),
            Err(_) => Err(Error::validation_failed(
            Ok(Err(_)) => Err(ErrorWrapper::cancelled_no_context()),
            Err(_) => Err(ErrorWrapper::validation_failed(
                "Field submission timeout".to_string(),
            )),
        }
@ -303,7 +306,7 @@ impl FormBackend for WebBackend {
        let state = self
            .state
            .as_ref()
            .ok_or_else(|| Error::validation_failed("Server not initialized"))?;
            .ok_or_else(|| ErrorWrapper::validation_failed("Server not initialized"))?;

        // Initialize results with initial values (for field rendering)
        {
@ -389,8 +392,8 @@ impl FormBackend for WebBackend {

                Ok(all_results)
            }
            Ok(Err(_)) => Err(Error::cancelled()),
            Err(_) => Err(Error::validation_failed(
            Ok(Err(_)) => Err(ErrorWrapper::cancelled_no_context()),
            Err(_) => Err(ErrorWrapper::validation_failed(
                "Form submission timeout".to_string(),
            )),
        }
@ -404,13 +407,13 @@ impl FormBackend for WebBackend {
        if let Some(handle) = self.server_handle.take() {
            match tokio::time::timeout(Duration::from_secs(5), handle).await {
                Ok(Ok(())) => Ok(()),
                Ok(Err(e)) => Err(Error::validation_failed(format!(
                Ok(Err(e)) => Err(ErrorWrapper::validation_failed(format!(
                    "Server join error: {}",
                    e
                ))),
                Err(_) => {
                    // Timeout - server didn't shut down gracefully
                    Err(Error::validation_failed(
                    Err(ErrorWrapper::validation_failed(
                        "Server shutdown timeout".to_string(),
                    ))
                }
@ -766,7 +769,7 @@ async fn submit_field_handler(
    {
        let mut channels = state.field_channels.write().await;
        if let Some(tx) = channels.remove(&field_name) {
            let _ = tx.send(value.clone());
            drop(tx.send(value.clone()));
        }
    }

@ -818,7 +821,7 @@ async fn submit_complete_form_handler(
    {
        let mut complete_tx = state.complete_form_tx.write().await;
        if let Some(tx) = complete_tx.take() {
            let _ = tx.send(all_results.clone());
            drop(tx.send(all_results.clone()));
        }
    }
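// General `tokio::sync::oneshot` behaviour (standard API, for context):
// `Sender::send` consumes the sender and returns `Err(value)` when the receiver
// has been dropped, so the result is deliberately discarded here:
//
//     if let Some(tx) = channels.remove(&field_name) {
//         drop(tx.send(value.clone())); // receiver may already be gone; ignore
//     }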
141
crates/typedialog-core/src/config/cli_loader.rs
Normal file
@ -0,0 +1,141 @@
//! CLI configuration loader helper for all backends
//!
//! Provides a unified pattern for loading backend configuration files
//! with support for both explicit `-c FILE` and environment-based search.

use crate::error::{Error, Result};
use std::path::Path;

/// Load backend-specific configuration file
///
/// If `cli_config_path` is provided, uses that file exclusively.
/// Otherwise, searches in order:
/// 1. `~/.config/typedialog/{backend_name}/{TYPEDIALOG_ENV}.toml`
/// 2. `~/.config/typedialog/{backend_name}/config.toml`
/// 3. Returns default value
///
/// # Arguments
/// * `backend_name` - Name of backend (cli, tui, web, ai)
/// * `cli_config_path` - Optional explicit config file path from `-c` flag
/// * `default` - Default config to return if file not found
///
/// # Example
///
/// ```ignore
/// use typedialog_core::config::load_backend_config;
/// use std::path::PathBuf;
///
/// let config = load_backend_config::<MyBackendConfig>(
///     "cli",
///     Some(PathBuf::from("custom.toml").as_path()),
///     MyBackendConfig::default()
/// )?;
/// ```
pub fn load_backend_config<T>(
    backend_name: &str,
    cli_config_path: Option<&Path>,
    default: T,
) -> Result<T>
where
    T: serde::de::DeserializeOwned + Default,
{
    // If CLI path provided, use it exclusively
    if let Some(path) = cli_config_path {
        return load_from_file(path);
    }

    // Otherwise try search order
    let config_dir = dirs::config_dir()
        .unwrap_or_else(|| {
            std::path::PathBuf::from(std::env::var("HOME").unwrap_or_else(|_| ".".to_string()))
        })
        .join("typedialog")
        .join(backend_name);

    // Try environment-specific config first
    let env = std::env::var("TYPEDIALOG_ENV").unwrap_or_else(|_| "default".to_string());
    let env_config_path = config_dir.join(format!("{}.toml", env));
    if env_config_path.exists() {
        if let Ok(config) = load_from_file::<T>(&env_config_path) {
            return Ok(config);
        }
    }

    // Try generic config.toml
    let generic_config_path = config_dir.join("config.toml");
    if generic_config_path.exists() {
        if let Ok(config) = load_from_file::<T>(&generic_config_path) {
            return Ok(config);
        }
    }

    // Return default
    Ok(default)
}
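// Resolution sketch (illustrative; `TuiConfig` is a hypothetical config type):
//
//     std::env::set_var("TYPEDIALOG_ENV", "staging");
//     let cfg: TuiConfig = load_backend_config("tui", None, TuiConfig::default())?;
//     // tries ~/.config/typedialog/tui/staging.toml,
//     // then  ~/.config/typedialog/tui/config.toml,
//     // and finally falls back to TuiConfig::default()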
/// Load configuration from TOML file
fn load_from_file<T>(path: &Path) -> Result<T>
where
    T: serde::de::DeserializeOwned,
{
    let content = std::fs::read_to_string(path).map_err(|e| {
        Error::validation_failed(format!(
            "Failed to read config file '{}': {}",
            path.display(),
            e
        ))
    })?;

    toml::from_str(&content).map_err(|e| {
        Error::validation_failed(format!(
            "Failed to parse config file '{}': {}",
            path.display(),
            e
        ))
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[derive(serde::Serialize, serde::Deserialize, Debug, Default, PartialEq, Clone)]
    struct TestConfig {
        name: String,
        value: i32,
    }

    #[test]
    fn test_load_with_default() {
        let default = TestConfig {
            name: "default".to_string(),
            value: 42,
        };
        let config = load_backend_config::<TestConfig>("test", None, default.clone()).unwrap();
        // Should return default when no config found
        assert_eq!(config, default);
    }

    #[test]
    fn test_load_with_explicit_path() {
        // Create test config file
        let temp_dir = std::env::temp_dir();
        let test_config_path = temp_dir.join("test-backend-config.toml");
        let test_content = r#"
name = "test"
value = 100
"#;
        std::fs::write(&test_config_path, test_content).ok();

        let default = TestConfig::default();
        let config =
            load_backend_config::<TestConfig>("test", Some(test_config_path.as_path()), default)
                .unwrap();

        assert_eq!(config.name, "test");
        assert_eq!(config.value, 100);

        // Cleanup
        std::fs::remove_file(test_config_path).ok();
    }
}
@ -1,7 +1,7 @@
//! Configuration file loader

use crate::config::TypeDialogConfig;
use crate::error::{Error, Result};
use crate::error::{ErrorWrapper, Result};
use std::fs;
use std::path::PathBuf;

@ -14,7 +14,7 @@ pub fn load_global_config() -> Result<TypeDialogConfig> {
    if config_path.exists() {
        let content = fs::read_to_string(&config_path)?;
        toml::from_str(&content).map_err(|e| {
            Error::config_not_found(format!(
            ErrorWrapper::config_not_found(format!(
                "Failed to parse config file at {:?}: {}",
                config_path, e
            ))
@ -29,8 +29,9 @@ fn get_config_path() -> Result<PathBuf> {
    #[cfg(feature = "i18n")]
    {
        use dirs::config_dir;
        let config_dir = config_dir()
            .ok_or_else(|| Error::config_not_found("Unable to determine config directory"))?;
        let config_dir = config_dir().ok_or_else(|| {
            ErrorWrapper::config_not_found("Unable to determine config directory")
        })?;
        Ok(config_dir.join("typedialog").join("config.toml"))
    }

@ -39,7 +40,7 @@ fn get_config_path() -> Result<PathBuf> {
    // Fallback without dirs dependency
    std::env::var("HOME")
        .or_else(|_| std::env::var("USERPROFILE"))
        .map_err(|_| Error::config_not_found("Unable to determine home directory"))
        .map_err(|_| ErrorWrapper::config_not_found("Unable to determine home directory"))
        .map(|home| PathBuf::from(home).join(".config/typedialog/config.toml"))
    }
}

@ -1,9 +1,12 @@
//! Configuration management for typedialog
//!
//! Handles global configuration loading and defaults.
//! Provides unified backend configuration loading with CLI support.

pub mod cli_loader;
mod loader;

pub use cli_loader::load_backend_config;
pub use loader::load_global_config;

use serde::{Deserialize, Serialize};
@ -52,7 +55,7 @@ impl Default for TypeDialogConfig {
    fn default() -> Self {
        Self {
            locale: None,
            locales_path: PathBuf::from("./locales"),
            locales_path: PathBuf::from("./examples/06-i18n"),
            templates_path: PathBuf::from("./templates"),
            fallback_locale: "en-US".to_string(),
            encryption: None,
@ -3,8 +3,8 @@
//! Converts field encryption configuration to BackendSpec for use with the
//! unified encryption API from the `encrypt` crate.

use crate::form_parser::FieldDefinition;
use crate::error::Result;
use crate::form_parser::FieldDefinition;
use std::collections::HashMap;

/// Convert FieldDefinition encryption configuration to BackendSpec.
@ -64,9 +64,9 @@ pub fn field_to_backend_spec(
            .get("vault_addr")
            .or_else(|| config.get("address"))
            .ok_or_else(|| {
                crate::error::Error::new(
                    crate::error::ErrorKind::Other,
                    "SecretumVault backend requires vault_addr in encryption_config".to_string(),
                crate::error::ErrorWrapper::new(
                    "SecretumVault backend requires vault_addr in encryption_config"
                        .to_string(),
                )
            })?;

@ -74,9 +74,9 @@ pub fn field_to_backend_spec(
            .get("vault_token")
            .or_else(|| config.get("token"))
            .ok_or_else(|| {
                crate::error::Error::new(
                    crate::error::ErrorKind::Other,
                    "SecretumVault backend requires vault_token in encryption_config".to_string(),
                crate::error::ErrorWrapper::new(
                    "SecretumVault backend requires vault_token in encryption_config"
                        .to_string(),
                )
            })?;

@ -93,22 +93,19 @@ pub fn field_to_backend_spec(
            ))
        }
        "awskms" => {
            let region = config
                .get("region")
                .ok_or_else(|| {
                    crate::error::Error::new(
                        crate::error::ErrorKind::Other,
                        "AWS KMS backend requires region in encryption_config".to_string(),
                    )
                })?;
            let region = config.get("region").ok_or_else(|| {
                crate::error::ErrorWrapper::new(
                    "AWS KMS backend requires region in encryption_config".to_string(),
                )
            })?;

            let key_id = config
                .get("key_id")
                .or_else(|| config.get("key_arn"))
                .ok_or_else(|| {
                    crate::error::Error::new(
                        crate::error::ErrorKind::Other,
                        "AWS KMS backend requires key_id or key_arn in encryption_config".to_string(),
                    crate::error::ErrorWrapper::new(
                        "AWS KMS backend requires key_id or key_arn in encryption_config"
                            .to_string(),
                    )
                })?;

@ -118,30 +115,23 @@ pub fn field_to_backend_spec(
            ))
        }
        "gcpkms" => {
            let project_id = config
                .get("project_id")
                .ok_or_else(|| {
                    crate::error::Error::new(
                        crate::error::ErrorKind::Other,
                        "GCP KMS backend requires project_id in encryption_config".to_string(),
                    )
                })?;
            let project_id = config.get("project_id").ok_or_else(|| {
                crate::error::ErrorWrapper::new(
                    "GCP KMS backend requires project_id in encryption_config".to_string(),
                )
            })?;

            let key_ring = config
                .get("key_ring")
                .ok_or_else(|| {
                    crate::error::Error::new(
                        crate::error::ErrorKind::Other,
                        "GCP KMS backend requires key_ring in encryption_config".to_string(),
                    )
                })?;
            let key_ring = config.get("key_ring").ok_or_else(|| {
                crate::error::ErrorWrapper::new(
                    "GCP KMS backend requires key_ring in encryption_config".to_string(),
                )
            })?;

            let crypto_key = config
                .get("crypto_key")
                .or_else(|| config.get("key"))
                .ok_or_else(|| {
                    crate::error::Error::new(
                        crate::error::ErrorKind::Other,
                    crate::error::ErrorWrapper::new(
                        "GCP KMS backend requires crypto_key in encryption_config".to_string(),
                    )
                })?;
@ -159,33 +149,27 @@ pub fn field_to_backend_spec(
            ))
        }
        "azurekms" => {
            let vault_name = config
                .get("vault_name")
                .ok_or_else(|| {
                    crate::error::Error::new(
                        crate::error::ErrorKind::Other,
                        "Azure KMS backend requires vault_name in encryption_config".to_string(),
                    )
                })?;
            let vault_name = config.get("vault_name").ok_or_else(|| {
                crate::error::ErrorWrapper::new(
                    "Azure KMS backend requires vault_name in encryption_config".to_string(),
                )
            })?;

            let tenant_id = config
                .get("tenant_id")
                .ok_or_else(|| {
                    crate::error::Error::new(
                        crate::error::ErrorKind::Other,
                        "Azure KMS backend requires tenant_id in encryption_config".to_string(),
                    )
                })?;
            let tenant_id = config.get("tenant_id").ok_or_else(|| {
                crate::error::ErrorWrapper::new(
                    "Azure KMS backend requires tenant_id in encryption_config".to_string(),
                )
            })?;

            Ok(encrypt::BackendSpec::azure_kms(
                vault_name.clone(),
                tenant_id.clone(),
            ))
        }
        backend => Err(crate::error::Error::new(
            crate::error::ErrorKind::ValidationFailed,
            format!("Unknown encryption backend: {}", backend),
        )),
        backend => Err(crate::error::ErrorWrapper::new(format!(
            "Unknown encryption backend: {}",
            backend
        ))),
    }
}
@ -1,185 +1,673 @@
//! Error handling for typedialog
//! Error handling for typedialog - Canonical Error Structs
//!
//! Provides structured error types for all operations.
//! Provides structured error types for all operations following M-ERRORS-CANONICAL-STRUCTS.
//! Each error type is a concrete struct with specific error kind enum.

#![allow(clippy::result_large_err)]

use std::fmt;
use std::io;
use std::path::PathBuf;

/// Errors that can occur during form operations
// ============================================================================
// CANCELLED ERROR
// ============================================================================

/// User cancelled the operation
#[derive(Debug)]
pub struct Error {
    kind: ErrorKind,
    message: String,
pub struct CancelledError {
    pub context: String,
}

/// Error kinds for form operations
#[derive(Debug)]
pub enum ErrorKind {
    /// User cancelled the prompt
    Cancelled,
    /// Form parsing failed
    FormParseFailed,
    /// I/O error
    Io,
    /// TOML parsing error
    TomlParse,
    /// Validation failed
    ValidationFailed,
    /// i18n (internationalization) error
    I18nFailed,
    /// Template error
    TemplateFailed,
    /// Configuration error
    ConfigNotFound,
    /// Other errors
    Other,
impl fmt::Display for CancelledError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Operation cancelled{}",
            if self.context.is_empty() {
                String::new()
            } else {
                format!(": {}", self.context)
            }
        )
    }
}

impl Error {
    /// Create a new error with a specific kind and message
    pub fn new(kind: ErrorKind, message: impl Into<String>) -> Self {
        Self {
            kind,
            message: message.into(),
impl std::error::Error for CancelledError {}

// ============================================================================
// FORM PARSE ERROR
// ============================================================================

/// Form parsing failed
#[derive(Debug)]
pub struct FormParseError {
    pub kind: FormParseErrorKind,
    pub message: String,
    pub source: Option<Box<dyn std::error::Error + Send + Sync>>,
}

#[derive(Debug)]
pub enum FormParseErrorKind {
    InvalidToml {
        line: usize,
        column: usize,
    },
    MissingField {
        field: String,
    },
    InvalidFieldType {
        field: String,
        expected: String,
        got: String,
    },
    InvalidOption {
        field: String,
        option: String,
    },
    MissingRepeatingGroupFragment {
        field: String,
    },
}

impl FormParseError {
    pub fn is_missing_field(&self) -> bool {
        matches!(self.kind, FormParseErrorKind::MissingField { .. })
    }

    pub fn missing_field_name(&self) -> Option<&str> {
        if let FormParseErrorKind::MissingField { field } = &self.kind {
            Some(field)
        } else {
            None
        }
    }
/// Create a cancelled error
|
||||
pub fn cancelled() -> Self {
|
||||
Self::new(ErrorKind::Cancelled, "Operation cancelled")
|
||||
}
|
||||
|
||||
/// Create a form parse error
|
||||
pub fn form_parse_failed(msg: impl Into<String>) -> Self {
|
||||
Self::new(ErrorKind::FormParseFailed, msg)
|
||||
}
|
||||
|
||||
/// Create an I/O error
|
||||
pub fn io(source: io::Error) -> Self {
|
||||
Self::new(ErrorKind::Io, format!("I/O error: {}", source))
|
||||
}
|
||||
|
||||
/// Create a TOML parse error
|
||||
pub fn toml_parse(source: toml::de::Error) -> Self {
|
||||
Self::new(
|
||||
ErrorKind::TomlParse,
|
||||
format!("TOML parse error: {}", source),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a validation error
|
||||
pub fn validation_failed(msg: impl Into<String>) -> Self {
|
||||
Self::new(ErrorKind::ValidationFailed, msg)
|
||||
}
|
||||
|
||||
/// Create an i18n error
|
||||
pub fn i18n_failed(msg: impl Into<String>) -> Self {
|
||||
Self::new(ErrorKind::I18nFailed, msg)
|
||||
}
|
||||
|
||||
/// Create a template error
|
||||
pub fn template_failed(msg: impl Into<String>) -> Self {
|
||||
Self::new(ErrorKind::TemplateFailed, msg)
|
||||
}
|
||||
|
||||
/// Create a config error
|
||||
pub fn config_not_found(msg: impl Into<String>) -> Self {
|
||||
Self::new(ErrorKind::ConfigNotFound, msg)
|
||||
}
|
||||
|
||||
/// Get the error kind
|
||||
pub fn kind(&self) -> &ErrorKind {
|
||||
&self.kind
|
||||
}
|
||||
|
||||
/// Get the error message
|
||||
pub fn message(&self) -> &str {
|
||||
&self.message
|
||||
}
|
||||
|
||||
/// Check if this is a cancellation error
|
||||
pub fn is_cancelled(&self) -> bool {
|
||||
matches!(self.kind, ErrorKind::Cancelled)
|
||||
}
|
||||
|
||||
/// Check if this is a parse error
|
||||
pub fn is_parse_error(&self) -> bool {
|
||||
matches!(self.kind, ErrorKind::FormParseFailed | ErrorKind::TomlParse)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
impl fmt::Display for FormParseError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.message)
|
||||
match &self.kind {
|
||||
FormParseErrorKind::InvalidToml { line, column } => {
|
||||
write!(
|
||||
f,
|
||||
"TOML parse error at line {}, column {}: {}",
|
||||
line, column, self.message
|
||||
)
|
||||
}
|
||||
FormParseErrorKind::MissingField { field } => {
|
||||
write!(f, "Missing field '{}': {}", field, self.message)
|
||||
}
|
||||
FormParseErrorKind::InvalidFieldType {
|
||||
field,
|
||||
expected,
|
||||
got,
|
||||
} => {
|
||||
write!(f, "Field '{}': expected {}, got {}", field, expected, got)
|
||||
}
|
||||
FormParseErrorKind::InvalidOption { field, option } => {
|
||||
write!(f, "Field '{}': invalid option '{}'", field, option)
|
||||
}
|
||||
FormParseErrorKind::MissingRepeatingGroupFragment { field } => {
|
||||
write!(f, "RepeatingGroup '{}': missing fragment template", field)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
impl From<io::Error> for Error {
|
||||
fn from(err: io::Error) -> Self {
|
||||
Self::io(err)
|
||||
impl std::error::Error for FormParseError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
self.source
|
||||
.as_ref()
|
||||
.map(|e| e.as_ref() as &dyn std::error::Error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<toml::de::Error> for Error {
|
||||
fn from(err: toml::de::Error) -> Self {
|
||||
Self::toml_parse(err)
|
||||
// ============================================================================
|
||||
// I/O ERROR
|
||||
// ============================================================================
|
||||
|
||||
/// I/O operation failed
|
||||
#[derive(Debug)]
|
||||
pub struct IoError {
|
||||
pub operation: String,
|
||||
pub path: Option<PathBuf>,
|
||||
pub source: io::Error,
|
||||
}
|
||||
|
||||
impl fmt::Display for IoError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match &self.path {
|
||||
Some(p) => write!(
|
||||
f,
|
||||
"I/O error ({}) on '{}': {}",
|
||||
self.operation,
|
||||
p.display(),
|
||||
self.source
|
||||
),
|
||||
None => write!(f, "I/O error ({}): {}", self.operation, self.source),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for Error {
|
||||
fn from(err: serde_json::Error) -> Self {
|
||||
Self::new(ErrorKind::Other, format!("JSON error: {}", err))
|
||||
impl std::error::Error for IoError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
Some(&self.source)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_yaml::Error> for Error {
|
||||
fn from(err: serde_yaml::Error) -> Self {
|
||||
Self::new(ErrorKind::Other, format!("YAML error: {}", err))
|
||||
// ============================================================================
|
||||
// VALIDATION ERROR
|
||||
// ============================================================================
|
||||
|
||||
/// Validation failed
|
||||
#[derive(Debug)]
|
||||
pub struct ValidationError {
|
||||
pub kind: ValidationErrorKind,
|
||||
pub field: String,
|
||||
pub value: Option<serde_json::Value>,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ValidationErrorKind {
|
||||
RequiredFieldMissing,
|
||||
ContractViolation { contract: String, reason: String },
|
||||
TypeMismatch { expected: String, got: String },
|
||||
RangeViolation { min: Option<f64>, max: Option<f64> },
|
||||
InvalidDate { format: String },
|
||||
}
|
||||
|
||||
impl ValidationError {
|
||||
pub fn is_required_field(&self) -> bool {
|
||||
matches!(self.kind, ValidationErrorKind::RequiredFieldMissing)
|
||||
}
|
||||
|
||||
pub fn is_contract_violation(&self) -> bool {
|
||||
matches!(self.kind, ValidationErrorKind::ContractViolation { .. })
|
||||
}
|
||||
|
||||
pub fn contract_details(&self) -> Option<(&str, &str)> {
|
||||
if let ValidationErrorKind::ContractViolation { contract, reason } = &self.kind {
|
||||
Some((contract, reason))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<chrono::ParseError> for Error {
|
||||
fn from(err: chrono::ParseError) -> Self {
|
||||
Self::new(
|
||||
ErrorKind::ValidationFailed,
|
||||
format!("Date parsing error: {}", err),
|
||||
impl fmt::Display for ValidationError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Validation error on field '{}': {}",
|
||||
self.field,
|
||||
match &self.kind {
|
||||
ValidationErrorKind::RequiredFieldMissing => "field is required".to_string(),
|
||||
ValidationErrorKind::ContractViolation { contract, reason } =>
|
||||
format!("contract '{}' violated ({})", contract, reason),
|
||||
ValidationErrorKind::TypeMismatch { expected, got } =>
|
||||
format!("expected {}, got {}", expected, got),
|
||||
ValidationErrorKind::RangeViolation { min, max } => match (min, max) {
|
||||
(Some(min_val), Some(max_val)) =>
|
||||
format!("value must be between {} and {}", min_val, max_val),
|
||||
(Some(min_val), None) => format!("value must be >= {}", min_val),
|
||||
(None, Some(max_val)) => format!("value must be <= {}", max_val),
|
||||
(None, None) => "value out of range".to_string(),
|
||||
},
|
||||
ValidationErrorKind::InvalidDate { format } =>
|
||||
format!("invalid date format: {}", format),
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<inquire::InquireError> for Error {
|
||||
impl std::error::Error for ValidationError {}
|
||||
|
||||
// ============================================================================
|
||||
// I18N ERROR
|
||||
// ============================================================================
|
||||
|
||||
/// Internationalization error
|
||||
#[derive(Debug)]
|
||||
pub struct I18nError {
|
||||
pub kind: I18nErrorKind,
|
||||
pub locale: Option<String>,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum I18nErrorKind {
|
||||
LoadFailed { path: PathBuf },
|
||||
MessageNotFound { key: String },
|
||||
InvalidSyntax { line: usize },
|
||||
}
|
||||
|
||||
impl fmt::Display for I18nError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let locale_str = self
|
||||
.locale
|
||||
.as_ref()
|
||||
.map(|l| format!(" (locale: {})", l))
|
||||
.unwrap_or_default();
|
||||
|
||||
match &self.kind {
|
||||
I18nErrorKind::LoadFailed { path } => write!(
|
||||
f,
|
||||
"Failed to load i18n from '{}'{}: {}",
|
||||
path.display(),
|
||||
locale_str,
|
||||
self.message
|
||||
),
|
||||
I18nErrorKind::MessageNotFound { key } => {
|
||||
write!(f, "Message not found: '{}'{}", key, locale_str)
|
||||
}
|
||||
I18nErrorKind::InvalidSyntax { line } => write!(
|
||||
f,
|
||||
"Invalid i18n syntax at line {}{}: {}",
|
||||
line, locale_str, self.message
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for I18nError {}
|
||||
|
||||
// ============================================================================
|
||||
// TEMPLATE ERROR
|
||||
// ============================================================================
|
||||
|
||||
/// Template processing failed
|
||||
#[derive(Debug)]
|
||||
pub struct TemplateError {
|
||||
pub kind: TemplateErrorKind,
|
||||
pub template_name: Option<String>,
|
||||
pub message: String,
|
||||
pub source: Option<Box<dyn std::error::Error + Send + Sync>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TemplateErrorKind {
|
||||
ParseFailed { line: usize, column: usize },
|
||||
RenderFailed { variable: String },
|
||||
VariableNotFound { name: String },
|
||||
}
|
||||
|
||||
impl fmt::Display for TemplateError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let template = self
|
||||
.template_name
|
||||
.as_ref()
|
||||
.map(|t| format!(" in '{}'", t))
|
||||
.unwrap_or_default();
|
||||
|
||||
match &self.kind {
|
||||
TemplateErrorKind::ParseFailed { line, column } => write!(
|
||||
f,
|
||||
"Template parse error at line {}, column {}{}: {}",
|
||||
line, column, template, self.message
|
||||
),
|
||||
TemplateErrorKind::RenderFailed { variable } => write!(
f,
"Failed to render variable '{}'{}: {}",
variable, template, self.message
),
TemplateErrorKind::VariableNotFound { name } => {
write!(f, "Variable '{}' not found{}", name, template)
}
}
}
}

impl std::error::Error for TemplateError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.source
.as_ref()
.map(|e| e.as_ref() as &dyn std::error::Error)
}
}

// ============================================================================
// CONFIG NOT FOUND ERROR
// ============================================================================

/// Configuration file not found
#[derive(Debug)]
pub struct ConfigNotFoundError {
pub search_paths: Vec<PathBuf>,
pub config_name: String,
}

impl fmt::Display for ConfigNotFoundError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Configuration '{}' not found in: {}",
self.config_name,
self.search_paths
.iter()
.map(|p| p.display().to_string())
.collect::<Vec<_>>()
.join(", ")
)
}
}

impl std::error::Error for ConfigNotFoundError {}

// ============================================================================
// UNIFIED ERROR WRAPPER (for public API boundaries)
// ============================================================================

/// Unified error type wrapping all specific error types
#[derive(Debug)]
pub enum ErrorWrapper {
Cancelled(CancelledError),
FormParse(FormParseError),
Io(IoError),
Validation(ValidationError),
I18n(I18nError),
Template(TemplateError),
ConfigNotFound(ConfigNotFoundError),
}

impl fmt::Display for ErrorWrapper {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ErrorWrapper::Cancelled(e) => write!(f, "{}", e),
ErrorWrapper::FormParse(e) => write!(f, "{}", e),
ErrorWrapper::Io(e) => write!(f, "{}", e),
ErrorWrapper::Validation(e) => write!(f, "{}", e),
ErrorWrapper::I18n(e) => write!(f, "{}", e),
ErrorWrapper::Template(e) => write!(f, "{}", e),
ErrorWrapper::ConfigNotFound(e) => write!(f, "{}", e),
}
}
}

impl std::error::Error for ErrorWrapper {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
ErrorWrapper::FormParse(e) => e.source(),
ErrorWrapper::Io(e) => e.source(),
ErrorWrapper::Template(e) => e.source(),
_ => None,
}
}
}

// ============================================================================
// FROM IMPLEMENTATIONS
// ============================================================================

impl From<io::Error> for ErrorWrapper {
fn from(err: io::Error) -> Self {
ErrorWrapper::Io(IoError {
operation: "unknown".into(),
path: None,
source: err,
})
}
}

impl From<io::Error> for IoError {
fn from(err: io::Error) -> Self {
IoError {
operation: "unknown".into(),
path: None,
source: err,
}
}
}

impl From<toml::de::Error> for ErrorWrapper {
fn from(err: toml::de::Error) -> Self {
// toml 0.9 doesn't expose line/column in error, use defaults
ErrorWrapper::FormParse(FormParseError {
kind: FormParseErrorKind::InvalidToml { line: 0, column: 0 },
message: err.to_string(),
source: None,
})
}
}

impl From<toml::de::Error> for FormParseError {
fn from(err: toml::de::Error) -> Self {
// toml 0.9 doesn't expose line/column in error, use defaults
FormParseError {
kind: FormParseErrorKind::InvalidToml { line: 0, column: 0 },
message: err.to_string(),
source: None,
}
}
}

impl From<serde_json::Error> for ErrorWrapper {
fn from(err: serde_json::Error) -> Self {
ErrorWrapper::Validation(ValidationError {
kind: ValidationErrorKind::TypeMismatch {
expected: "valid JSON".into(),
got: format!("JSON error: {}", err),
},
field: "json".into(),
value: None,
message: err.to_string(),
})
}
}

impl From<serde_yaml::Error> for ErrorWrapper {
fn from(err: serde_yaml::Error) -> Self {
ErrorWrapper::Validation(ValidationError {
kind: ValidationErrorKind::TypeMismatch {
expected: "valid YAML".into(),
got: format!("YAML error: {}", err),
},
field: "yaml".into(),
value: None,
message: err.to_string(),
})
}
}

impl From<chrono::ParseError> for ErrorWrapper {
fn from(err: chrono::ParseError) -> Self {
ErrorWrapper::Validation(ValidationError {
kind: ValidationErrorKind::InvalidDate {
format: "see chrono documentation".into(),
},
field: "date".into(),
value: None,
message: err.to_string(),
})
}
}

impl From<inquire::InquireError> for ErrorWrapper {
fn from(err: inquire::InquireError) -> Self {
match err {
inquire::InquireError::OperationCanceled => Self::cancelled(),
_ => Self::new(ErrorKind::Other, format!("Prompt error: {}", err)),
inquire::InquireError::OperationCanceled => ErrorWrapper::Cancelled(CancelledError {
context: String::new(),
}),
_ => ErrorWrapper::Validation(ValidationError {
kind: ValidationErrorKind::TypeMismatch {
expected: "valid input".into(),
got: "prompt error".into(),
},
field: "prompt".into(),
value: None,
message: err.to_string(),
}),
}
}
}

/// Result type for typedialog operations
pub type Result<T> = std::result::Result<T, Error>;
// ============================================================================
// PUBLIC RESULT TYPE
// ============================================================================

pub type Result<T> = std::result::Result<T, ErrorWrapper>;

/// Error type alias for convenient use
pub type Error = ErrorWrapper;

// ============================================================================
// HELPER CONSTRUCTORS (for migration compatibility)
// ============================================================================

impl ErrorWrapper {
pub fn cancelled(context: impl Into<String>) -> Self {
ErrorWrapper::Cancelled(CancelledError {
context: context.into(),
})
}

pub fn cancelled_no_context() -> Self {
ErrorWrapper::Cancelled(CancelledError {
context: String::new(),
})
}

pub fn form_parse_failed(msg: impl Into<String>) -> Self {
ErrorWrapper::FormParse(FormParseError {
kind: FormParseErrorKind::InvalidToml { line: 0, column: 0 },
message: msg.into(),
source: None,
})
}

/// Helper for converting io::Error to IoError (supports .map_err(ErrorWrapper::io))
pub fn io(source: io::Error) -> Self {
ErrorWrapper::Io(IoError {
operation: "unknown".into(),
path: None,
source,
})
}

pub fn validation_failed(msg: impl Into<String>) -> Self {
let msg_str = msg.into();
ErrorWrapper::Validation(ValidationError {
kind: ValidationErrorKind::TypeMismatch {
expected: "valid input".into(),
got: msg_str.clone(),
},
field: "input".into(),
value: None,
message: msg_str,
})
}

pub fn validation_failed_field(field: impl Into<String>, msg: impl Into<String>) -> Self {
let msg_str = msg.into();
ErrorWrapper::Validation(ValidationError {
kind: ValidationErrorKind::TypeMismatch {
expected: "valid input".into(),
got: msg_str.clone(),
},
field: field.into(),
value: None,
message: msg_str,
})
}

pub fn i18n_failed(msg: impl Into<String>) -> Self {
ErrorWrapper::I18n(I18nError {
kind: I18nErrorKind::LoadFailed {
path: PathBuf::new(),
},
locale: None,
message: msg.into(),
})
}

pub fn i18n_failed_locale(locale: impl Into<String>, msg: impl Into<String>) -> Self {
ErrorWrapper::I18n(I18nError {
kind: I18nErrorKind::LoadFailed {
path: PathBuf::new(),
},
locale: Some(locale.into()),
message: msg.into(),
})
}

pub fn template_failed(msg: impl Into<String>) -> Self {
ErrorWrapper::Template(TemplateError {
kind: TemplateErrorKind::ParseFailed { line: 0, column: 0 },
template_name: None,
message: msg.into(),
source: None,
})
}

pub fn template_failed_named(template: impl Into<String>, msg: impl Into<String>) -> Self {
ErrorWrapper::Template(TemplateError {
kind: TemplateErrorKind::ParseFailed { line: 0, column: 0 },
template_name: Some(template.into()),
message: msg.into(),
source: None,
})
}

pub fn config_not_found(config_name: impl Into<String>) -> Self {
ErrorWrapper::ConfigNotFound(ConfigNotFoundError {
search_paths: Vec::new(),
config_name: config_name.into(),
})
}

pub fn is_cancelled(&self) -> bool {
matches!(self, ErrorWrapper::Cancelled(_))
}

/// Generic error constructor for migration (treats as form parse error)
pub fn new(msg: impl Into<String>) -> Self {
ErrorWrapper::form_parse_failed(msg)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_error_display() {
let err = Error::validation_failed("test error");
assert_eq!(err.to_string(), "test error");
fn test_cancelled_error() {
let err = CancelledError {
context: "user click".into(),
};
assert_eq!(err.to_string(), "Operation cancelled: user click");
}

#[test]
fn test_error_kind() {
let err = Error::cancelled();
fn test_form_parse_error() {
let err = FormParseError {
kind: FormParseErrorKind::InvalidToml { line: 1, column: 5 },
message: "unexpected token".into(),
source: None,
};
assert!(err.to_string().contains("TOML parse error"));
}

#[test]
fn test_validation_error_required() {
let err = ValidationError {
kind: ValidationErrorKind::RequiredFieldMissing,
field: "name".into(),
value: None,
message: "required".into(),
};
assert!(err.is_required_field());
}

#[test]
fn test_error_wrapper_cancelled() {
let err = ErrorWrapper::cancelled("test");
assert!(err.is_cancelled());
}

#[test]
fn test_parse_error() {
let err = Error::form_parse_failed("parse failed");
assert!(err.is_parse_error());
fn test_from_io_error() {
let io_err = io::Error::new(io::ErrorKind::NotFound, "file not found");
let err: ErrorWrapper = io_err.into();
assert!(matches!(err, ErrorWrapper::Io(_)));
}
}

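A minimal sketch of what the migration looks like at a call site. The old enum-kind API (Error::new(ErrorKind::…, msg)) gives way to the helper constructors above; check_port is a hypothetical function for illustration:

use typedialog_core::error::{ErrorWrapper, Result};

// Hypothetical call site exercising the canonical error structs.
fn check_port(port: u16) -> Result<()> {
    if port == 0 {
        // Builds a structured ValidationError carrying the field name
        return Err(ErrorWrapper::validation_failed_field(
            "port",
            "port must be non-zero",
        ));
    }
    Ok(())
}

fn main() {
    if let Err(e) = check_port(0) {
        // Display delegates to the wrapped ValidationError
        eprintln!("{}", e);
        assert!(!e.is_cancelled());
    }
}
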
@ -16,7 +16,7 @@ fn default_order() -> usize {

/// Form element (can be a display item or a field)
/// Public enum for unified form structure
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone)]
pub enum FormElement {
Item(DisplayItem),
Field(FieldDefinition),
@ -94,7 +94,7 @@ impl<'de> Deserialize<'de> for FormElement {
impl<'de> de::Visitor<'de> for ElementVisitor {
type Value = FormElement;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("a FormElement with a type field")
}

@ -167,6 +167,18 @@ impl<'de> Deserialize<'de> for FormElement {
}
}

impl Serialize for FormElement {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
FormElement::Item(item) => item.serialize(serializer),
FormElement::Field(field) => field.serialize(serializer),
}
}
}

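A self-contained sketch of the delegation pattern this impl uses, with stand-in types (Item and Field here are simplifications, not the real structs). Serializing per variant keeps the output untagged, so no {"Item": ...} or {"Field": ...} wrapper appears and the visitor-based Deserialize above can read the value back:

use serde::{Serialize, Serializer};

#[derive(Serialize)]
struct Item { title: String }

#[derive(Serialize)]
struct Field { name: String }

enum Element { Item(Item), Field(Field) }

impl Serialize for Element {
    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        match self {
            Element::Item(i) => i.serialize(s),   // emits the inner item directly
            Element::Field(f) => f.serialize(s),  // no variant tag in the output
        }
    }
}

fn main() {
    let e = Element::Field(Field { name: "username".into() });
    // Prints {"name":"username"}; the variant name never appears.
    println!("{}", serde_json::to_string(&e).unwrap());
}
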
/// A display item (header, section, CTA, footer, etc.)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DisplayItem {
@ -599,10 +611,7 @@ pub enum DisplayMode {
/// Resolve constraint interpolations in TOML content
/// Replaces "${constraint.path.to.value}" (with quotes) with actual values from constraints.toml
/// The quotes are removed as part of the replacement, so the value becomes a bare number
fn resolve_constraints_in_content(
content: &str,
base_dir: &Path,
) -> Result<String> {
fn resolve_constraints_in_content(content: &str, base_dir: &Path) -> Result<String> {
let constraints_path = base_dir.join("constraints.toml");

// If constraints.toml doesn't exist, return content unchanged
@ -612,7 +621,10 @@ fn resolve_constraints_in_content(

let constraints_content = std::fs::read_to_string(&constraints_path)?;
let constraints_table: toml::Table = toml::from_str(&constraints_content).map_err(|e| {
crate::error::Error::validation_failed(format!("Failed to parse constraints.toml: {}", e))
crate::error::ErrorWrapper::validation_failed(format!(
"Failed to parse constraints.toml: {}",
e
))
})?;

let mut result = content.to_string();
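A sketch of the interpolation this function performs, using the same constraint path as the integration test further down; the value is invented:

# constraints.toml
[tracker.udp]
max_items = 50

# In the form TOML, the quoted placeholder
max_items = "${constraint.tracker.udp.max_items}"
# is replaced, quotes included, leaving a bare number:
max_items = 50
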
@ -1341,7 +1353,7 @@ fn execute_field(

FieldType::Select => {
if field.options.is_empty() {
return Err(crate::Error::form_parse_failed(
return Err(crate::ErrorWrapper::form_parse_failed(
"Select field requires 'options'",
));
}
@ -1362,7 +1374,7 @@ fn execute_field(

FieldType::MultiSelect => {
if field.options.is_empty() {
return Err(crate::Error::form_parse_failed(
return Err(crate::ErrorWrapper::form_parse_failed(
"MultiSelect field requires 'options'",
));
}
@ -1417,7 +1429,7 @@ fn execute_field(
FieldType::Custom => {
let prompt_with_marker = format!("{}{}", field.prompt, required_marker);
let type_name = field.custom_type.as_ref().ok_or_else(|| {
crate::Error::form_parse_failed("Custom field requires 'custom_type'")
crate::ErrorWrapper::form_parse_failed("Custom field requires 'custom_type'")
})?;
let result = prompts::custom(&prompt_with_marker, type_name, field.default.as_deref())?;

@ -1430,7 +1442,7 @@ fn execute_field(

FieldType::RepeatingGroup => {
// Temporary stub - will be implemented in Phase 4
Err(crate::Error::form_parse_failed(
Err(crate::ErrorWrapper::form_parse_failed(
"RepeatingGroup not yet implemented - use CLI backend (Phase 4)",
))
}
@ -3002,7 +3014,7 @@ mod integration_tests {

// Verify results
assert_eq!(
results.get("account_type").map(|v| v.as_str()).flatten(),
results.get("account_type").and_then(|v| v.as_str()),
Some("basic")
);
}
@ -3044,7 +3056,7 @@ mod integration_tests {
// When enable_feature is true, feature_config should be executed
assert!(executed.contains(&"feature_config".to_string()));
assert_eq!(
results.get("feature_config").map(|v| v.as_str()).flatten(),
results.get("feature_config").and_then(|v| v.as_str()),
Some("custom_config")
);
}
@ -3165,11 +3177,11 @@ mod integration_tests {

// Verify results contain selected values
assert_eq!(
results.get("provider").map(|v| v.as_str()).flatten(),
results.get("provider").and_then(|v| v.as_str()),
Some("lxd")
);
assert_eq!(
results.get("db_type").map(|v| v.as_str()).flatten(),
results.get("db_type").and_then(|v| v.as_str()),
Some("mysql")
);
}
@ -3217,16 +3229,13 @@ max_items = "${constraint.tracker.udp.max_items}"

// Verify the interpolation worked
assert_eq!(form.name, "Test Constraints Form");
let udp_field = form
.elements
.iter()
.find(|e| {
if let FormElement::Field(f) = e {
f.name == "udp_items"
} else {
false
}
});
let udp_field = form.elements.iter().find(|e| {
if let FormElement::Field(f) = e {
f.name == "udp_items"
} else {
false
}
});

assert!(udp_field.is_some());
let field = udp_field.unwrap().as_field().unwrap();

@ -23,19 +23,13 @@ pub fn format_results(
match format {
"json" => {
let json_obj = serde_json::to_value(results).map_err(|e| {
crate::Error::new(
crate::error::ErrorKind::Other,
format!("JSON serialization error: {}", e),
)
crate::ErrorWrapper::new(format!("JSON serialization error: {}", e))
})?;
Ok(serde_json::to_string_pretty(&json_obj)?)
}
"yaml" => {
let yaml_string = serde_yaml::to_string(results).map_err(|e| {
crate::Error::new(
crate::error::ErrorKind::Other,
format!("YAML serialization error: {}", e),
)
crate::ErrorWrapper::new(format!("YAML serialization error: {}", e))
})?;
Ok(yaml_string)
}
@ -46,16 +40,12 @@ pub fn format_results(
}
Ok(output)
}
"toml" => toml::to_string_pretty(results).map_err(|e| {
crate::Error::new(
crate::error::ErrorKind::Other,
format!("TOML serialization error: {}", e),
)
}),
_ => Err(crate::Error::new(
crate::error::ErrorKind::ValidationFailed,
format!("Unknown output format: {}", format),
)),
"toml" => toml::to_string_pretty(results)
.map_err(|e| crate::ErrorWrapper::new(format!("TOML serialization error: {}", e))),
_ => Err(crate::ErrorWrapper::new(format!(
"Unknown output format: {}",
format
))),
}
}

@ -87,9 +77,8 @@ pub fn to_json_value(results: &HashMap<String, Value>) -> Value {

/// Convert results to JSON string
pub fn to_json_string(results: &HashMap<String, Value>) -> crate::error::Result<String> {
serde_json::to_string(&to_json_value(results)).map_err(|e| {
crate::Error::new(crate::error::ErrorKind::Other, format!("JSON error: {}", e))
})
serde_json::to_string(&to_json_value(results))
.map_err(|e| crate::ErrorWrapper::new(format!("JSON error: {}", e)))
}

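A usage sketch of format_results, assuming it takes just the results map and a format name (the real signature may carry more parameters; the map contents are invented):

use std::collections::HashMap;
use serde_json::{json, Value};
use typedialog_core::helpers::format_results;

fn main() -> typedialog_core::Result<()> {
    let mut results: HashMap<String, Value> = HashMap::new();
    results.insert("username".into(), json!("bob"));

    // "json", "yaml", and "toml" are handled by the match above
    println!("{}", format_results(&results, "json")?);
    // Unknown formats now surface as ErrorWrapper
    assert!(format_results(&results, "xml").is_err());
    Ok(())
}
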
/// Encryption context controlling redaction/encryption behavior
@ -160,7 +149,10 @@ pub fn resolve_encryption_config(
backend.clone()
} else if let Some(config) = global_config {
// Priority 3: Global config
config.default_backend.clone().unwrap_or_else(|| "age".to_string())
config
.default_backend
.clone()
.unwrap_or_else(|| "age".to_string())
} else {
// Priority 4: Hard default
"age".to_string()
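Spelled out, the resolution ladder this implements; priorities 1 and 2 (field backend, then context backend) sit earlier in the function. A free-function sketch of the same logic:

// Sketch only: the real function also returns the backend's config.
fn resolve_backend(
    field: Option<String>,
    context: Option<String>,
    global_default: Option<String>,
) -> String {
    field                                // Priority 1: field-level setting
        .or(context)                     // Priority 2: encryption context
        .or(global_default)              // Priority 3: global config
        .unwrap_or_else(|| "age".into()) // Priority 4: hard default
}
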
@ -293,12 +285,7 @@ fn transform_sensitive_value(
// Use unified encryption API with bridge module
let spec = crate::encryption_bridge::field_to_backend_spec(field, default_backend)?;
let ciphertext = encrypt::encrypt(&plaintext, &spec)
.map_err(|e| {
crate::Error::new(
crate::error::ErrorKind::Other,
format!("Encryption failed: {}", e),
)
})?;
.map_err(|e| crate::ErrorWrapper::new(format!("Encryption failed: {}", e)))?;

return Ok(Value::String(ciphertext));
}
@ -454,7 +441,10 @@ mod tests {
let context = EncryptionContext::encrypt_with("rustyvault", context_config.clone());

let (backend, _config) = resolve_encryption_config(&field, &context, None).unwrap();
assert_eq!(backend, "age", "Field backend should have priority over context");
assert_eq!(
backend, "age",
"Field backend should have priority over context"
);
}

#[test]
@ -501,7 +491,10 @@ mod tests {
let context = EncryptionContext::encrypt_with("rustyvault", context_config);

let (backend, _config) = resolve_encryption_config(&field, &context, None).unwrap();
assert_eq!(backend, "rustyvault", "Context backend should be used when field has none");
assert_eq!(
backend, "rustyvault",
"Context backend should be used when field has none"
);
}

#[test]
@ -545,7 +538,10 @@ mod tests {

let context = EncryptionContext::noop();
let (backend, _config) = resolve_encryption_config(&field, &context, None).unwrap();
assert_eq!(backend, "age", "Should default to 'age' when nothing specified");
assert_eq!(
backend, "age",
"Should default to 'age' when nothing specified"
);
}

#[test]
@ -634,6 +630,7 @@ mod tests {
}

// Helper function to create minimal FieldDefinition for tests
#[allow(dead_code)]
fn make_text_field(name: &str, sensitive: bool) -> crate::form_parser::FieldDefinition {
crate::form_parser::FieldDefinition {
name: name.to_string(),
@ -678,7 +675,10 @@ mod tests {
results.insert("username".to_string(), json!("bob"));
results.insert("password".to_string(), json!("topsecret"));

let fields = vec![make_text_field("username", false), make_text_field("password", true)];
let fields = vec![
make_text_field("username", false),
make_text_field("password", true),
];

let context = EncryptionContext::redact_only();
let output = format_results_secure(&results, &fields, "json", &context, None).unwrap();

@ -1,6 +1,6 @@
//! Locale file loader for Fluent and TOML translations

use crate::error::{Error, Result};
use crate::error::{ErrorWrapper, Result};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
@ -24,7 +24,7 @@ impl LocaleLoader {
let locale_dir = self.locales_path.join(locale.to_string());

if !locale_dir.exists() {
return Err(Error::i18n_failed(format!(
return Err(ErrorWrapper::i18n_failed(format!(
"Locale directory not found: {:?}",
locale_dir
)));
@ -41,7 +41,7 @@ impl LocaleLoader {
match fs::read_to_string(&path) {
Ok(content) => resources.push(content),
Err(e) => {
return Err(Error::i18n_failed(format!(
return Err(ErrorWrapper::i18n_failed(format!(
"Failed to read {:?}: {}",
path, e
)))
@ -51,7 +51,7 @@ impl LocaleLoader {
}
}
Err(e) => {
return Err(Error::i18n_failed(format!(
return Err(ErrorWrapper::i18n_failed(format!(
"Failed to read locale directory {:?}: {}",
locale_dir, e
)))
@ -74,12 +74,12 @@ impl LocaleLoader {
match fs::read_to_string(&toml_file) {
Ok(content) => match toml::from_str::<toml::map::Map<String, toml::Value>>(&content) {
Ok(root) => Ok(Self::flatten_toml(root, "")),
Err(e) => Err(Error::i18n_failed(format!(
Err(e) => Err(ErrorWrapper::i18n_failed(format!(
"Failed to parse TOML file {:?}: {}",
toml_file, e
))),
},
Err(e) => Err(Error::i18n_failed(format!(
Err(e) => Err(ErrorWrapper::i18n_failed(format!(
"Failed to read TOML file {:?}: {}",
toml_file, e
))),

@ -8,7 +8,7 @@ mod resolver;
pub use loader::LocaleLoader;
pub use resolver::LocaleResolver;

use crate::error::{Error, Result};
use crate::error::{ErrorWrapper, Result};
use fluent::FluentArgs;
use fluent_bundle::{FluentBundle, FluentResource};
use std::collections::HashMap;
@ -48,10 +48,10 @@ impl I18nBundle {

for resource_str in loader.load_fluent(locale)? {
let resource = FluentResource::try_new(resource_str)
.map_err(|e| Error::i18n_failed(format!("Fluent parse error: {:?}", e)))?;
.map_err(|e| ErrorWrapper::i18n_failed(format!("Fluent parse error: {:?}", e)))?;
bundle
.add_resource(resource)
.map_err(|e| Error::i18n_failed(format!("Bundle add error: {:?}", e)))?;
.map_err(|e| ErrorWrapper::i18n_failed(format!("Bundle add error: {:?}", e)))?;
}

Ok(bundle)
@ -60,7 +60,7 @@ impl I18nBundle {
/// Translate a message key
///
/// Searches in order: main bundle → TOML translations → fallback bundle → missing key marker
pub fn translate(&self, key: &str, args: Option<&FluentArgs>) -> String {
pub fn translate(&self, key: &str, args: Option<&FluentArgs<'_>>) -> String {
// Try main bundle
if let Some(msg) = self.bundle.get_message(key) {
if let Some(pattern) = msg.value() {
@ -100,7 +100,7 @@ impl I18nBundle {
}

/// Translate if the text looks like a key, otherwise return as-is
pub fn translate_if_key(&self, text: &str, args: Option<&FluentArgs>) -> String {
pub fn translate_if_key(&self, text: &str, args: Option<&FluentArgs<'_>>) -> String {
if Self::is_i18n_key(text) {
self.translate(text, args)
} else {

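A stand-alone sketch of the first link in that lookup chain, using the fluent-bundle API directly, the same resolution translate performs before falling back to TOML translations and the fallback bundle; the FTL content is invented:

use fluent_bundle::{FluentBundle, FluentResource};
use unic_langid::langid;

fn main() {
    let res = FluentResource::try_new("welcome = Hello, world!".to_string())
        .expect("valid FTL");
    let mut bundle = FluentBundle::new(vec![langid!("en-US")]);
    bundle.add_resource(res).expect("no message conflicts");

    // Mirrors the main-bundle branch: get_message → value() → format_pattern
    let msg = bundle.get_message("welcome").expect("message exists");
    let pattern = msg.value().expect("message has a value");
    let mut errors = vec![];
    let text = bundle.format_pattern(pattern, None, &mut errors);
    assert_eq!(text, "Hello, world!");
}
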
@ -1,3 +1,5 @@
#![allow(clippy::result_large_err)]

//! typedialog - Interactive forms and prompts library
//!
//! A powerful library and CLI tool for creating interactive forms and prompts
@ -65,6 +67,9 @@ pub mod helpers;
pub mod nickel;
pub mod prompts;

#[cfg(feature = "ai_backend")]
pub mod ai;

#[cfg(feature = "i18n")]
pub mod config;

@ -86,7 +91,7 @@ pub use encrypt;
// Re-export main types for convenient access
pub use autocompletion::{FilterCompleter, HistoryCompleter, PatternCompleter};
pub use backends::{BackendFactory, BackendType, FormBackend, RenderContext};
pub use error::{Error, Result};
pub use error::{Error, ErrorWrapper, Result};
pub use form_parser::{DisplayItem, FieldDefinition, FieldType, FormDefinition};
pub use helpers::{format_results, to_json_string, to_json_value};

@ -108,7 +113,9 @@ mod tests {

#[test]
fn test_version() {
assert!(!VERSION.is_empty());
// VERSION is compile-time constant from Cargo.toml, verify it's valid
assert!(VERSION.chars().next().unwrap().is_ascii_digit());
assert!(VERSION.contains('.'));
}

#[test]

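For consumers, the new module is opted into through the ai_backend feature; a sketch of the dependency stanza (the version is illustrative):

[dependencies]
typedialog-core = { version = "0.x", features = ["ai_backend"] }

Code gated the same way compiles only when the feature is on:

#[cfg(feature = "ai_backend")]
use typedialog_core::ai;
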
@ -135,7 +135,7 @@ mod tests {

#[test]
fn test_single_element_returns_none() {
assert_eq!(AliasGenerator::generate(&vec!["name".to_string()]), None);
assert_eq!(AliasGenerator::generate(&["name".to_string()]), None);
}

#[test]

@ -5,7 +5,7 @@
//! - `nickel export` - Export evaluated Nickel to JSON
//! - `nickel typecheck` - Validate Nickel syntax and types

use crate::error::{Error, ErrorKind};
use crate::error::ErrorWrapper;
use crate::Result;
use std::path::Path;
use std::process::Command;
@ -24,29 +24,22 @@ impl NickelCli {
.arg("--version")
.output()
.map_err(|e| {
Error::new(
ErrorKind::Other,
format!(
"Failed to execute 'nickel --version'. \
ErrorWrapper::new(format!(
"Failed to execute 'nickel --version'. \
Is nickel installed? Install from: https://nickel-lang.org/install\n\
Error: {}",
e
),
)
e
))
})?;

if !output.status.success() {
return Err(Error::new(
ErrorKind::Other,
return Err(ErrorWrapper::new(
"nickel command failed. Is nickel installed correctly?",
));
}

String::from_utf8(output.stdout).map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Invalid UTF-8 in nickel version output: {}", e),
)
ErrorWrapper::new(format!("Invalid UTF-8 in nickel version output: {}", e))
})
}

@ -73,36 +66,31 @@ impl NickelCli {
}

let output = cmd.output().map_err(|e| {
Error::new(
ErrorKind::Other,
format!(
"Failed to execute 'nickel query' on {}: {}",
path.display(),
e
),
)
ErrorWrapper::new(format!(
"Failed to execute 'nickel query' on {}: {}",
path.display(),
e
))
})?;

if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(Error::new(
ErrorKind::ValidationFailed,
format!("nickel query failed for {}: {}", path.display(), stderr),
));
return Err(ErrorWrapper::new(format!(
"nickel query failed for {}: {}",
path.display(),
stderr
)));
}

let stdout = String::from_utf8(output.stdout).map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Invalid UTF-8 in nickel query output: {}", e),
)
ErrorWrapper::new(format!("Invalid UTF-8 in nickel query output: {}", e))
})?;

serde_json::from_str(&stdout).map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Failed to parse nickel query output as JSON: {}", e),
)
ErrorWrapper::new(format!(
"Failed to parse nickel query output as JSON: {}",
e
))
})
}

@ -127,36 +115,31 @@ impl NickelCli {
.arg(path)
.output()
.map_err(|e| {
Error::new(
ErrorKind::Other,
format!(
"Failed to execute 'nickel export' on {}: {}",
path.display(),
e
),
)
ErrorWrapper::new(format!(
"Failed to execute 'nickel export' on {}: {}",
path.display(),
e
))
})?;

if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(Error::new(
ErrorKind::ValidationFailed,
format!("nickel export failed for {}: {}", path.display(), stderr),
));
return Err(ErrorWrapper::new(format!(
"nickel export failed for {}: {}",
path.display(),
stderr
)));
}

let stdout = String::from_utf8(output.stdout).map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Invalid UTF-8 in nickel export output: {}", e),
)
ErrorWrapper::new(format!("Invalid UTF-8 in nickel export output: {}", e))
})?;

serde_json::from_str(&stdout).map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Failed to parse nickel export output as JSON: {}", e),
)
ErrorWrapper::new(format!(
"Failed to parse nickel export output as JSON: {}",
e
))
})
}

@ -176,12 +159,9 @@ impl NickelCli {
pub fn typecheck(path: &Path) -> Result<()> {
// Execute typecheck from the file's directory to allow relative imports to resolve
let parent_dir = path.parent().unwrap_or_else(|| std::path::Path::new("."));
let filename = path.file_name().ok_or_else(|| {
Error::new(
ErrorKind::Other,
"Cannot extract filename from path".to_string(),
)
})?;
let filename = path
.file_name()
.ok_or_else(|| ErrorWrapper::new("Cannot extract filename from path".to_string()))?;

let output = Command::new("nickel")
.current_dir(parent_dir)
@ -189,26 +169,20 @@ impl NickelCli {
.arg(filename)
.output()
.map_err(|e| {
Error::new(
ErrorKind::Other,
format!(
"Failed to execute 'nickel typecheck' on {}: {}",
path.display(),
e
),
)
ErrorWrapper::new(format!(
"Failed to execute 'nickel typecheck' on {}: {}",
path.display(),
e
))
})?;

if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(Error::new(
ErrorKind::ValidationFailed,
format!(
"nickel typecheck failed for {}:\n{}",
path.display(),
stderr
),
));
return Err(ErrorWrapper::new(format!(
"nickel typecheck failed for {}:\n{}",
path.display(),
stderr
)));
}

Ok(())

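A usage sketch of the wrapper; the schema path is hypothetical and the module path for NickelCli is assumed. Since these methods shell out to the nickel binary, the example only runs where nickel is installed:

use std::path::Path;
use typedialog_core::nickel::NickelCli; // path assumed

fn main() -> typedialog_core::Result<()> {
    // Runs `nickel typecheck` from the file's parent directory so that
    // relative imports resolve, per the implementation above.
    NickelCli::typecheck(Path::new("forms/schema.ncl"))?;
    Ok(())
}
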
@ -7,12 +7,14 @@
//! - `std.string.NonEmpty` - Non-empty string
//! - `std.string.length.min N` - Minimum string length
//! - `std.string.length.max N` - Maximum string length
//! - `std.string.Email` - Valid email address format
//! - `std.string.Url` - Valid URL format (http/https/ftp/ftps)
//! - `std.number.between A B` - Number in range [A, B]
//! - `std.number.greater_than N` - Number > N
//! - `std.number.less_than N` - Number < N

use super::schema_ir::{NickelFieldIR, NickelSchemaIR, NickelType};
use crate::error::{Error, ErrorKind};
use crate::error::ErrorWrapper;
use crate::Result;
use serde_json::Value;

@ -72,6 +74,14 @@ impl ContractValidator {
}
}

if predicate.contains("std.string.Email") {
return Self::validate_email(value);
}

if predicate.contains("std.string.Url") {
return Self::validate_url(value);
}

// Unknown predicate - pass validation
Ok(())
}
@ -81,18 +91,14 @@ impl ContractValidator {
match value {
Value::String(s) => {
if s.is_empty() {
Err(Error::new(
ErrorKind::ValidationFailed,
Err(ErrorWrapper::new(
"String must not be empty (std.string.NonEmpty)".to_string(),
))
} else {
Ok(())
}
}
_ => Err(Error::new(
ErrorKind::ValidationFailed,
"Expected string value".to_string(),
)),
_ => Err(ErrorWrapper::new("Expected string value".to_string())),
}
}

@ -101,21 +107,15 @@ impl ContractValidator {
match value {
Value::String(s) => {
if s.len() < min {
Err(Error::new(
ErrorKind::ValidationFailed,
format!(
"String must be at least {} characters (std.string.length.min {})",
min, min
),
))
Err(ErrorWrapper::new(format!(
"String must be at least {} characters (std.string.length.min {})",
min, min
)))
} else {
Ok(())
}
}
_ => Err(Error::new(
ErrorKind::ValidationFailed,
"Expected string value".to_string(),
)),
_ => Err(ErrorWrapper::new("Expected string value".to_string())),
}
}

@ -124,21 +124,15 @@ impl ContractValidator {
match value {
Value::String(s) => {
if s.len() > max {
Err(Error::new(
ErrorKind::ValidationFailed,
format!(
"String must be at most {} characters (std.string.length.max {})",
max, max
),
))
Err(ErrorWrapper::new(format!(
"String must be at most {} characters (std.string.length.max {})",
max, max
)))
} else {
Ok(())
}
}
_ => Err(Error::new(
ErrorKind::ValidationFailed,
"Expected string value".to_string(),
)),
_ => Err(ErrorWrapper::new("Expected string value".to_string())),
}
}

@ -150,25 +144,16 @@ impl ContractValidator {
if num >= a && num <= b {
Ok(())
} else {
Err(Error::new(
ErrorKind::ValidationFailed,
format!(
"Number must be between {} and {} (std.number.between {} {})",
a, b, a, b
),
))
Err(ErrorWrapper::new(format!(
"Number must be between {} and {} (std.number.between {} {})",
a, b, a, b
)))
}
} else {
Err(Error::new(
ErrorKind::ValidationFailed,
"Invalid number value".to_string(),
))
Err(ErrorWrapper::new("Invalid number value".to_string()))
}
}
_ => Err(Error::new(
ErrorKind::ValidationFailed,
"Expected number value".to_string(),
)),
_ => Err(ErrorWrapper::new("Expected number value".to_string())),
}
}

@ -180,25 +165,16 @@ impl ContractValidator {
if val > n {
Ok(())
} else {
Err(Error::new(
ErrorKind::ValidationFailed,
format!(
"Number must be greater than {} (std.number.greater_than {})",
n, n
),
))
Err(ErrorWrapper::new(format!(
"Number must be greater than {} (std.number.greater_than {})",
n, n
)))
}
} else {
Err(Error::new(
ErrorKind::ValidationFailed,
"Invalid number value".to_string(),
))
Err(ErrorWrapper::new("Invalid number value".to_string()))
}
}
_ => Err(Error::new(
ErrorKind::ValidationFailed,
"Expected number value".to_string(),
)),
_ => Err(ErrorWrapper::new("Expected number value".to_string())),
}
}

@ -210,25 +186,16 @@ impl ContractValidator {
if val < n {
Ok(())
} else {
Err(Error::new(
ErrorKind::ValidationFailed,
format!(
"Number must be less than {} (std.number.less_than {})",
n, n
),
))
Err(ErrorWrapper::new(format!(
"Number must be less than {} (std.number.less_than {})",
n, n
)))
}
} else {
Err(Error::new(
ErrorKind::ValidationFailed,
"Invalid number value".to_string(),
))
Err(ErrorWrapper::new("Invalid number value".to_string()))
}
}
_ => Err(Error::new(
ErrorKind::ValidationFailed,
"Expected number value".to_string(),
)),
_ => Err(ErrorWrapper::new("Expected number value".to_string())),
}
}

@ -259,6 +226,101 @@ impl ContractValidator {

Some((a, b))
}

/// Validate email address format
///
/// Uses simple structural checks to verify basic email format.
/// Pattern: local@domain, where the local part may contain alphanumerics, dots,
/// hyphens, and underscores, and the domain must contain at least one dot.
pub fn validate_email(value: &Value) -> Result<()> {
match value {
Value::String(s) => {
if s.is_empty() {
return Err(ErrorWrapper::new(
"Email address cannot be empty".to_string(),
));
}

let at_count = s.matches('@').count();
if at_count != 1 {
return Err(ErrorWrapper::new(
"Email must contain exactly one @ symbol".to_string(),
));
}

let parts: Vec<&str> = s.split('@').collect();
let local = parts[0];
let domain = parts[1];

if local.is_empty() {
return Err(ErrorWrapper::new(
"Email local part cannot be empty".to_string(),
));
}

if domain.is_empty() {
return Err(ErrorWrapper::new(
"Email domain cannot be empty".to_string(),
));
}

if !domain.contains('.') {
return Err(ErrorWrapper::new(
"Email domain must contain at least one dot".to_string(),
));
}

if domain.starts_with('.') || domain.ends_with('.') {
return Err(ErrorWrapper::new(
"Email domain cannot start or end with a dot".to_string(),
));
}

Ok(())
}
_ => Err(ErrorWrapper::new(
"Expected string value for email".to_string(),
)),
}
}

/// Validate URL format
///
/// Checks for basic URL structure: scheme://host with optional path.
/// Accepted schemes: http, https, ftp, ftps.
pub fn validate_url(value: &Value) -> Result<()> {
match value {
Value::String(s) => {
if s.is_empty() {
return Err(ErrorWrapper::new("URL cannot be empty".to_string()));
}

let valid_schemes = ["http://", "https://", "ftp://", "ftps://"];
let has_valid_scheme = valid_schemes.iter().any(|scheme| s.starts_with(scheme));

if !has_valid_scheme {
return Err(ErrorWrapper::new(format!(
"URL must start with one of: {}",
valid_schemes.join(", ")
)));
}

let scheme_end = s.find("://").unwrap() + 3;
let rest = &s[scheme_end..];

if rest.is_empty() {
return Err(ErrorWrapper::new(
"URL must contain a host after the scheme".to_string(),
));
}

Ok(())
}
_ => Err(ErrorWrapper::new(
"Expected string value for URL".to_string(),
)),
}
}
}

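A quick sketch exercising the two new predicates; the module path for ContractValidator is assumed and the values are invented:

use serde_json::json;
use typedialog_core::nickel::ContractValidator; // path assumed

fn main() {
    // Email: exactly one '@', non-empty parts, dotted domain
    assert!(ContractValidator::validate_email(&json!("ada@example.com")).is_ok());
    assert!(ContractValidator::validate_email(&json!("ada@localhost")).is_err()); // no dot in domain
    // URL: accepted scheme plus a non-empty remainder
    assert!(ContractValidator::validate_url(&json!("https://example.com")).is_ok());
    assert!(ContractValidator::validate_url(&json!("example.com")).is_err()); // missing scheme
}
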
/// Analyzer for inferring conditional expressions from Nickel contracts
|
||||
|
||||
@ -27,6 +27,9 @@
|
||||
//! contract_call: None,
|
||||
//! group: None,
|
||||
//! fragment_marker: None,
|
||||
//! is_array_of_records: false,
|
||||
//! array_element_fields: None,
|
||||
//! encryption_metadata: None,
|
||||
//! },
|
||||
//! ],
|
||||
//! };
|
||||
@ -127,7 +130,7 @@ mod tests {
|
||||
fragment_marker: None,
|
||||
is_array_of_records: false,
|
||||
array_element_fields: None,
|
||||
encryption_metadata: None,
|
||||
encryption_metadata: None,
|
||||
},
|
||||
NickelFieldIR {
|
||||
path: vec![
|
||||
@ -148,7 +151,7 @@ mod tests {
|
||||
fragment_marker: None,
|
||||
is_array_of_records: false,
|
||||
array_element_fields: None,
|
||||
encryption_metadata: None,
|
||||
encryption_metadata: None,
|
||||
},
|
||||
NickelFieldIR {
|
||||
path: vec!["ssh_credentials".to_string(), "username".to_string()],
|
||||
@ -164,7 +167,7 @@ mod tests {
|
||||
fragment_marker: None,
|
||||
is_array_of_records: false,
|
||||
array_element_fields: None,
|
||||
encryption_metadata: None,
|
||||
encryption_metadata: None,
|
||||
},
|
||||
NickelFieldIR {
|
||||
path: vec!["ssh_credentials".to_string(), "port".to_string()],
|
||||
@ -180,7 +183,7 @@ mod tests {
|
||||
fragment_marker: None,
|
||||
is_array_of_records: false,
|
||||
array_element_fields: None,
|
||||
encryption_metadata: None,
|
||||
encryption_metadata: None,
|
||||
},
|
||||
NickelFieldIR {
|
||||
path: vec![
|
||||
@ -200,7 +203,7 @@ mod tests {
|
||||
fragment_marker: None,
|
||||
is_array_of_records: false,
|
||||
array_element_fields: None,
|
||||
encryption_metadata: None,
|
||||
encryption_metadata: None,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
@ -158,7 +158,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_extract_attributes() {
|
||||
let contract = "String | Sensitive Backend=\"age\" Key=\"/path/to/key\" Vault=\"http://vault\"";
|
||||
let contract =
|
||||
"String | Sensitive Backend=\"age\" Key=\"/path/to/key\" Vault=\"http://vault\"";
|
||||
let attrs = EncryptionContractParser::extract_attributes(contract);
|
||||
|
||||
assert_eq!(attrs.get("Backend"), Some(&"age".to_string()));
|
||||
|
||||
@ -38,7 +38,7 @@ impl FieldMapper {
|
||||
// alias must be unique, if present
|
||||
if let Some(ref alias) = field.alias {
|
||||
if let Some(existing) = alias_to_field.get(alias) {
|
||||
return Err(crate::error::Error::validation_failed(format!(
|
||||
return Err(crate::ErrorWrapper::validation_failed(format!(
|
||||
"Alias collision detected for '{}': used by multiple fields:\n - {:?}\n - {:?}\n\n\
|
||||
Solution: Use flat_names (with '-' separator) instead:\n - {}\n - {}",
|
||||
alias,
|
||||
@ -123,7 +123,7 @@ mod tests {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
},
NickelFieldIR {
path: vec!["ssh_credentials".to_string(), "username".to_string()],
@ -139,7 +139,7 @@ mod tests {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
},
NickelFieldIR {
path: vec!["simple".to_string()],
@ -155,7 +155,7 @@ mod tests {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
},
],
}
@ -230,8 +230,8 @@ mod tests {
#[test]
fn test_no_collision_with_hyphen_separator() {
// This test verifies the core benefit: hyphen separator prevents collisions
let path1 = vec!["ssh_credentials".to_string(), "username".to_string()];
let path2 = vec!["ssh".to_string(), "credentials_username".to_string()];
let path1 = ["ssh_credentials".to_string(), "username".to_string()];
let path2 = ["ssh".to_string(), "credentials_username".to_string()];

let flat1 = path1.join("-"); // ssh_credentials-username
let flat2 = path2.join("-"); // ssh-credentials_username

@ -8,7 +8,7 @@
//! - Additional locales parsed from schema metadata

use super::schema_ir::NickelSchemaIR;
use crate::error::{Error, ErrorKind};
use crate::error::ErrorWrapper;
use crate::Result;
use std::collections::HashMap;
use std::path::Path;
@ -55,16 +55,14 @@ impl I18nExtractor {
// Generate .ftl files for each locale
for (locale, messages) in &translations {
let ftl_dir = output_dir.join(locale);
std::fs::create_dir_all(&ftl_dir).map_err(|e| {
Error::new(ErrorKind::Io, format!("Failed to create locale dir: {}", e))
})?;
std::fs::create_dir_all(&ftl_dir)
.map_err(|e| ErrorWrapper::new(format!("Failed to create locale dir: {}", e)))?;

let ftl_path = ftl_dir.join("forms.ftl");
let ftl_content = Self::generate_ftl_content(messages);

std::fs::write(&ftl_path, ftl_content).map_err(|e| {
Error::new(ErrorKind::Io, format!("Failed to write .ftl file: {}", e))
})?;
std::fs::write(&ftl_path, ftl_content)
.map_err(|e| ErrorWrapper::new(format!("Failed to write .ftl file: {}", e)))?;
}

// Return mapping of field_name -> i18n_key

@ -11,7 +11,7 @@
//! - Fragment markers from `# @fragment: name` comments

use super::schema_ir::{NickelFieldIR, NickelSchemaIR, NickelType};
use crate::error::{Error, ErrorKind};
use crate::error::ErrorWrapper;
use crate::Result;
use serde_json::Value;
use std::collections::HashMap;
@ -35,12 +35,9 @@ impl MetadataParser {
///
/// Returns error if JSON structure is invalid or required fields are missing
pub fn parse(json: Value) -> Result<NickelSchemaIR> {
let obj = json.as_object().ok_or_else(|| {
Error::new(
ErrorKind::ValidationFailed,
"Expected JSON object from nickel query",
)
})?;
let obj = json
.as_object()
.ok_or_else(|| ErrorWrapper::new("Expected JSON object from nickel query"))?;

let mut fields = Vec::new();
Self::extract_fields(obj, Vec::new(), &mut fields)?;
@ -224,7 +221,7 @@ impl MetadataParser {
// For nested objects, extract fields
let mut nested_fields = Vec::new();
let path_copy = _path.to_vec();
let _ = Self::extract_fields(obj, path_copy, &mut nested_fields);
drop(Self::extract_fields(obj, path_copy, &mut nested_fields));
NickelType::Record(nested_fields)
}
}
@ -238,7 +235,7 @@ impl MetadataParser {
source_path: &Path,
) -> Result<HashMap<String, String>> {
let source_code = std::fs::read_to_string(source_path)
.map_err(|e| Error::new(ErrorKind::Io, format!("Failed to read source file: {}", e)))?;
.map_err(|e| ErrorWrapper::new(format!("Failed to read source file: {}", e)))?;

let mut markers = HashMap::new();
let mut current_fragment: Option<String> = None;

@ -102,10 +102,7 @@ impl RoundtripConfig {
}

let input_source = fs::read_to_string(&self.input_ncl).map_err(|e| {
crate::error::Error::new(
crate::error::ErrorKind::Other,
format!("Failed to read input file: {}", e),
)
crate::error::ErrorWrapper::new(format!("Failed to read input file: {}", e))
})?;

let input_contracts = ContractParser::parse_source(&input_source)?;
@ -170,10 +167,7 @@ impl RoundtripConfig {
}

fs::write(&self.output_ncl, &output_nickel).map_err(|e| {
crate::error::Error::new(
crate::error::ErrorKind::Other,
format!("Failed to write output file: {}", e),
)
crate::error::ErrorWrapper::new(format!("Failed to write output file: {}", e))
})?;

// Step 5: Validate if requested
@ -220,10 +214,7 @@ impl RoundtripConfig {
}

let input_source = fs::read_to_string(&self.input_ncl).map_err(|e| {
crate::error::Error::new(
crate::error::ErrorKind::Other,
format!("Failed to read input file: {}", e),
)
crate::error::ErrorWrapper::new(format!("Failed to read input file: {}", e))
})?;

let input_contracts = ContractParser::parse_source(&input_source)?;
@ -288,10 +279,7 @@ impl RoundtripConfig {
}

fs::write(&self.output_ncl, &output_nickel).map_err(|e| {
crate::error::Error::new(
crate::error::ErrorKind::Other,
format!("Failed to write output file: {}", e),
)
crate::error::ErrorWrapper::new(format!("Failed to write output file: {}", e))
})?;

// Step 5: Validate if requested
@ -333,10 +321,7 @@ impl RoundtripConfig {
) -> Result<HashMap<String, Value>> {
// Read form definition
let form_content = fs::read_to_string(form_path).map_err(|e| {
crate::error::Error::new(
crate::error::ErrorKind::Other,
format!("Failed to read form file: {}", e),
)
crate::error::ErrorWrapper::new(format!("Failed to read form file: {}", e))
})?;

// Parse TOML form definition

@ -473,7 +473,7 @@ mod tests {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
}],
};

@ -514,7 +514,7 @@ mod tests {
fragment_marker: None,
is_array_of_records: true,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
}],
};


@ -7,7 +7,7 @@
//! - Template loops, conditionals, and filters
//! - Passing form results as context to templates

use crate::error::{Error, ErrorKind};
use crate::error::ErrorWrapper;
use crate::Result;
use serde_json::Value;
use std::collections::HashMap;
@ -56,12 +56,8 @@ impl TemplateEngine {
#[cfg(feature = "templates")]
{
// Read template file
let template_content = fs::read_to_string(template_path).map_err(|e| {
Error::new(
ErrorKind::Io,
format!("Failed to read template file: {}", e),
)
})?;
let template_content = fs::read_to_string(template_path)
.map_err(|e| ErrorWrapper::new(format!("Failed to read template file: {}", e)))?;

// Add template to engine
let template_name = template_path
@ -71,30 +67,21 @@ impl TemplateEngine {

self.tera
.add_raw_template(template_name, &template_content)
.map_err(|e| {
Error::new(
ErrorKind::ValidationFailed,
format!("Failed to add template: {}", e),
)
})?;
.map_err(|e| ErrorWrapper::new(format!("Failed to add template: {}", e)))?;

// Build context from values with dual-key support if mapper provided
let mut context = Context::new();
self.build_dual_key_context(&mut context, values, mapper)?;

// Render template
self.tera.render(template_name, &context).map_err(|e| {
Error::new(
ErrorKind::ValidationFailed,
format!("Failed to render template: {}", e),
)
})
self.tera
.render(template_name, &context)
.map_err(|e| ErrorWrapper::new(format!("Failed to render template: {}", e)))
}

#[cfg(not(feature = "templates"))]
{
Err(Error::new(
ErrorKind::Other,
Err(ErrorWrapper::new(
"Template feature not enabled. Enable with --features templates".to_string(),
))
}
@ -118,30 +105,21 @@ impl TemplateEngine {
// Add template to engine
self.tera
.add_raw_template("inline", template)
.map_err(|e| {
Error::new(
ErrorKind::ValidationFailed,
format!("Failed to add template: {}", e),
)
})?;
.map_err(|e| ErrorWrapper::new(format!("Failed to add template: {}", e)))?;

// Build context from values with dual-key support if mapper provided
let mut context = Context::new();
self.build_dual_key_context(&mut context, values, mapper)?;

// Render template
self.tera.render("inline", &context).map_err(|e| {
Error::new(
ErrorKind::ValidationFailed,
format!("Failed to render template: {}", e),
)
})
self.tera
.render("inline", &context)
.map_err(|e| ErrorWrapper::new(format!("Failed to render template: {}", e)))
}

#[cfg(not(feature = "templates"))]
{
Err(Error::new(
ErrorKind::Other,
Err(ErrorWrapper::new(
"Template feature not enabled. Enable with --features templates".to_string(),
))
}
@ -202,7 +180,6 @@ mod tests {
fn test_template_engine_new() {
let _engine = TemplateEngine::new();
// Just verify it can be created
assert!(true);
}

#[cfg(feature = "templates")]
@ -263,66 +240,4 @@ age : Number = {{ age }}
let output = result.unwrap();
assert!(output.contains("monitoring_enabled"));
}

#[cfg(feature = "templates")]
#[test]
fn test_render_values_template() {
let mut engine = TemplateEngine::new();
let mut values = HashMap::new();

// Simulate form results with typical values
values.insert("environment_name".to_string(), json!("production"));
values.insert("provider".to_string(), json!("lxd"));
values.insert("lxd_profile_name".to_string(), json!("torrust-profile"));
values.insert("ssh_private_key_path".to_string(), json!("~/.ssh/id_rsa"));
values.insert("ssh_username".to_string(), json!("torrust"));
values.insert("ssh_port".to_string(), json!(22));
values.insert("database_driver".to_string(), json!("sqlite3"));
values.insert("sqlite_database_name".to_string(), json!("tracker.db"));
values.insert(
"udp_tracker_bind_address".to_string(),
json!("0.0.0.0:6969"),
);
values.insert(
"http_tracker_bind_address".to_string(),
json!("0.0.0.0:7070"),
);
values.insert("http_api_bind_address".to_string(), json!("0.0.0.0:1212"));
values.insert("http_api_admin_token".to_string(), json!("secret-token"));
values.insert("tracker_private_mode".to_string(), json!(false));
values.insert("enable_prometheus".to_string(), json!(false));
values.insert("enable_grafana".to_string(), json!(false));

// Find template file by checking from project root and current directory
let template_path =
if std::path::Path::new("provisioning/templates/values-template.ncl.j2").exists() {
std::path::Path::new("provisioning/templates/values-template.ncl.j2")
} else {
std::path::Path::new("../../provisioning/templates/values-template.ncl.j2")
};

let result = engine.render_file(template_path, &values, None);

assert!(result.is_ok(), "Template rendering failed: {:?}", result);
let output = result.unwrap();

// Verify key template content is present and correct
assert!(output.contains("validators.ValidEnvironmentName \"production\""));
assert!(output.contains("profile_name = \"torrust-profile\""));
assert!(output.contains("username = validators_username.ValidUsername \"torrust\""));
assert!(output.contains("port = validators_common.ValidPort 22"));
assert!(output.contains("driver = \"sqlite3\""));
assert!(output.contains("database_name = \"tracker.db\""));
assert!(output.contains("private = false"));
assert!(
output.contains("bind_address = validators_network.ValidBindAddress \"0.0.0.0:6969\"")
);

// Verify it's long enough to contain the full template
assert!(
output.len() > 2000,
"Output too short: {} bytes",
output.len()
);
}
}

@ -557,7 +557,7 @@ mod tests {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
},
NickelFieldIR {
path: vec!["age".to_string()],
@ -573,7 +573,7 @@ mod tests {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
},
],
};
@ -618,7 +618,7 @@ mod tests {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
},
NickelFieldIR {
path: vec!["settings_theme".to_string()],
@ -634,7 +634,7 @@ mod tests {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
},
],
};
@ -642,7 +642,7 @@ mod tests {
let form = TomlGenerator::generate(&schema, false, true).unwrap();

// Should have display items for groups
assert!(form.items.len() > 0);
assert!(!form.items.is_empty());

// Check fields are grouped
assert_eq!(form.fields[0].group, Some("user".to_string()));
@ -697,7 +697,7 @@ mod tests {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
};

let options = TomlGenerator::extract_enum_options(&field);
@ -727,7 +727,7 @@ mod tests {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
};

let schema = NickelSchemaIR {
@ -757,7 +757,7 @@ mod tests {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
};

let udp_trackers_field = NickelFieldIR {

@ -1,798 +0,0 @@
//! TOML Form Generator
//!
//! Converts Nickel schema intermediate representation (NickelSchemaIR)
//! into typedialog FormDefinition TOML format.
//!
//! Handles type mapping, metadata extraction, flatten/unflatten operations,
//! semantic grouping, and conditional expression inference from contracts.

use super::contracts::ContractAnalyzer;
use super::schema_ir::{NickelFieldIR, NickelSchemaIR, NickelType};
use crate::error::Result;
use crate::form_parser::{DisplayItem, FieldDefinition, FieldType, FormDefinition, SelectOption};
use std::collections::HashMap;

/// Generator for converting Nickel schemas to typedialog TOML forms
pub struct TomlGenerator;

impl TomlGenerator {
/// Convert a Nickel schema IR to a typedialog FormDefinition
///
/// # Arguments
///
/// * `schema` - The Nickel schema intermediate representation
/// * `flatten_records` - Whether to flatten nested records into flat field names
/// * `use_groups` - Whether to use semantic grouping for form organization
///
/// # Returns
///
/// FormDefinition ready to be serialized to TOML
pub fn generate(
schema: &NickelSchemaIR,
flatten_records: bool,
use_groups: bool,
) -> Result<FormDefinition> {
let mut fields = Vec::new();
let mut items = Vec::new();
let mut group_order: HashMap<String, usize> = HashMap::new();
let mut current_order = 0;

// First pass: collect all groups
if use_groups {
for field in &schema.fields {
if let Some(group) = &field.group {
group_order.entry(group.clone()).or_insert_with(|| {
let order = current_order;
current_order += 1;
order
});
}
}
}

// Generate display items for groups (headers)
let mut item_order = 0;
if use_groups {
for group in &schema
.fields
.iter()
.filter_map(|f| f.group.as_ref())
.collect::<std::collections::HashSet<_>>()
{
items.push(DisplayItem {
name: format!("{}_header", group),
item_type: "section".to_string(),
title: Some(format_group_title(group)),
border_top: Some(true),
group: Some(group.to_string()),
order: item_order,
content: None,
template: None,
border_bottom: None,
margin_left: None,
border_margin_left: None,
content_margin_left: None,
align: None,
when: None,
includes: None,
border_top_char: None,
border_top_len: None,
border_top_l: None,
border_top_r: None,
border_bottom_char: None,
border_bottom_len: None,
border_bottom_l: None,
border_bottom_r: None,
i18n: None,
});
item_order += 1;
}
}

// Second pass: generate fields
let mut field_order = item_order + 100; // Offset to allow items to display first
for field in &schema.fields {
let form_field =
Self::field_ir_to_definition(field, flatten_records, field_order, schema)?;
fields.push(form_field);
field_order += 1;
}

Ok(FormDefinition {
name: schema.name.clone(),
description: schema.description.clone(),
fields,
items,
elements: Vec::new(),
locale: None,
template: None,
output_template: None,
i18n_prefix: None,
display_mode: Default::default(),
})
}

/// Generate forms with fragments
///
/// Creates multiple FormDefinition objects: one main form with includes for each fragment,
/// and separate forms for each fragment containing only its fields.
///
/// # Returns
///
/// HashMap with "main" key for main form, and fragment names for fragment forms
pub fn generate_with_fragments(
schema: &NickelSchemaIR,
flatten_records: bool,
_use_groups: bool,
) -> Result<HashMap<String, FormDefinition>> {
let mut result = HashMap::new();

// Get all fragments and ungrouped fields
let fragments = schema.fragments();
let ungrouped_fields = schema.fields_without_fragment();

// Generate main form with includes
let mut main_items = Vec::new();
let mut main_fields = Vec::new();
let mut item_order = 0;
let mut field_order = 100;

// Add ungrouped fields to main form
if !ungrouped_fields.is_empty() {
for field in ungrouped_fields {
// Check if this is an array-of-records field
if field.is_array_of_records {
// Generate fragment for array element
let fragment_name = format!("{}_item", field.flat_name);

if let Some(element_fields) = &field.array_element_fields {
let fragment_form = Self::create_fragment_from_fields(
&fragment_name,
element_fields,
flatten_records,
schema,
)?;

result.insert(fragment_name.clone(), fragment_form);

// Generate RepeatingGroup field in main form
let repeating_field =
Self::create_repeating_group_field(field, &fragment_name, field_order)?;

main_fields.push(repeating_field);
field_order += 1;
}
} else {
// Normal field
let form_field =
Self::field_ir_to_definition(field, flatten_records, field_order, schema)?;
main_fields.push(form_field);
field_order += 1;
}
}
}

// Add includes for each fragment
for fragment in &fragments {
item_order += 1;
main_items.push(DisplayItem {
name: format!("{}_group", fragment),
item_type: "group".to_string(),
title: Some(format_group_title(fragment)),
includes: Some(vec![format!("fragments/{}.toml", fragment)]),
group: Some(fragment.clone()),
order: item_order,
content: None,
template: None,
border_top: None,
border_bottom: None,
margin_left: None,
border_margin_left: None,
content_margin_left: None,
align: None,
when: None,
border_top_char: None,
border_top_len: None,
border_top_l: None,
border_top_r: None,
border_bottom_char: None,
border_bottom_len: None,
border_bottom_l: None,
border_bottom_r: None,
i18n: None,
});
}

// Create main form
let main_form = FormDefinition {
name: format!("{}_main", schema.name),
description: schema.description.clone(),
fields: main_fields,
items: main_items,
elements: Vec::new(),
locale: None,
template: None,
output_template: None,
i18n_prefix: None,
display_mode: Default::default(),
};

result.insert("main".to_string(), main_form);

// Generate forms for each fragment
for fragment in &fragments {
let fragment_fields = schema.fields_by_fragment(fragment);

let mut fields = Vec::new();

for (field_order, field) in fragment_fields.into_iter().enumerate() {
let form_field =
Self::field_ir_to_definition(field, flatten_records, field_order, schema)?;
fields.push(form_field);
}

let fragment_form = FormDefinition {
name: format!("{}_fragment", fragment),
description: Some(format!("Fragment: {}", fragment)),
fields,
items: Vec::new(),
elements: Vec::new(),
locale: None,
template: None,
output_template: None,
i18n_prefix: None,
display_mode: Default::default(),
};

result.insert(fragment.clone(), fragment_form);
}

Ok(result)
}

/// Convert a single NickelFieldIR to a FieldDefinition
fn field_ir_to_definition(
field: &NickelFieldIR,
_flatten_records: bool,
order: usize,
schema: &NickelSchemaIR,
) -> Result<FieldDefinition> {
let (field_type, custom_type) = Self::nickel_type_to_field_type(&field.nickel_type)?;

let prompt = field
.doc
.clone()
.unwrap_or_else(|| format_prompt_from_path(&field.flat_name));

let default = field.default.as_ref().map(|v| match v {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Number(n) => n.to_string(),
serde_json::Value::Bool(b) => b.to_string(),
serde_json::Value::Null => String::new(),
other => other.to_string(),
});

let options = match &field.nickel_type {
NickelType::Array(_) => {
// Try to extract enum options from array element type or doc
Self::extract_enum_options(field)
}
_ => Vec::new(),
};

// Determine if field is required
let required = if field.optional {
Some(false)
} else {
Some(true)
};

// Infer conditional expression from contracts
let when_condition = ContractAnalyzer::infer_condition(field, schema);

Ok(FieldDefinition {
// Use alias if present (semantic name), otherwise use flat_name
name: field
.alias
.clone()
.unwrap_or_else(|| field.flat_name.clone()),
field_type,
prompt,
default,
placeholder: None,
options,
required,
file_extension: None,
prefix_text: None,
page_size: None,
vim_mode: None,
custom_type,
min_date: None,
max_date: None,
week_start: None,
order,
when: when_condition,
i18n: None,
group: field.group.clone(),
nickel_contract: field.contract.clone(),
nickel_path: Some(field.path.clone()),
nickel_doc: field.doc.clone(),
nickel_alias: field.alias.clone(),
fragment: None,
min_items: None,
max_items: None,
default_items: None,
unique: None,
unique_key: None,
sensitive: None,
encryption_backend: None,
encryption_config: None,
})
}

/// Map a Nickel type to typedialog field type
fn nickel_type_to_field_type(nickel_type: &NickelType) -> Result<(FieldType, Option<String>)> {
match nickel_type {
NickelType::String => Ok((FieldType::Text, None)),
NickelType::Number => Ok((FieldType::Custom, Some("f64".to_string()))),
NickelType::Bool => Ok((FieldType::Confirm, None)),
NickelType::Array(elem_type) => {
// Check if this is an array of records (repeating group)
if matches!(elem_type.as_ref(), NickelType::Record(_)) {
// Array of records -> use RepeatingGroup
Ok((FieldType::RepeatingGroup, None))
} else {
// Simple arrays -> use Editor with JSON for now
// (could be enhanced to MultiSelect if options are detected)
Ok((FieldType::Editor, Some("json".to_string())))
}
}
NickelType::Record(_) => {
// Records are handled by nested field generation
Ok((FieldType::Text, None))
}
NickelType::Custom(type_name) => {
// Unknown types map to custom with type name
Ok((FieldType::Custom, Some(type_name.clone())))
}
}
}

/// Extract enum options from field documentation or array structure
/// Returns Vec of SelectOption (with value only, no labels)
fn extract_enum_options(field: &NickelFieldIR) -> Vec<SelectOption> {
// Check if doc contains "Options: X, Y, Z" pattern
if let Some(doc) = &field.doc {
if let Some(start) = doc.find("Options:") {
let options_str = &doc[start + 8..]; // Skip "Options:"
let options: Vec<SelectOption> = options_str
.split(',')
.map(|s| SelectOption {
value: s.trim().to_string(),
label: None,
})
.filter(|opt| !opt.value.is_empty())
.collect();
if !options.is_empty() {
return options;
}
}
}

// For now, don't try to extract from array structure unless we have more info
Vec::new()
}

/// Create a FormDefinition fragment from array element fields
fn create_fragment_from_fields(
name: &str,
element_fields: &[NickelFieldIR],
flatten_records: bool,
schema: &NickelSchemaIR,
) -> Result<FormDefinition> {
let mut fields = Vec::new();

// Generate FieldDefinition for each element field
for (order, elem_field) in element_fields.iter().enumerate() {
let field_def =
Self::field_ir_to_definition(elem_field, flatten_records, order, schema)?;
fields.push(field_def);
}

Ok(FormDefinition {
name: name.to_string(),
description: Some(format!("Array element definition for {}", name)),
fields,
items: Vec::new(),
elements: Vec::new(),
locale: None,
template: None,
output_template: None,
i18n_prefix: None,
display_mode: Default::default(),
})
}

/// Create a RepeatingGroup FieldDefinition pointing to a fragment
fn create_repeating_group_field(
field: &NickelFieldIR,
fragment_name: &str,
order: usize,
) -> Result<FieldDefinition> {
let prompt = field
.doc
.as_ref()
.map(|d| d.lines().next().unwrap_or("").to_string())
.unwrap_or_else(|| {
field
.alias
.clone()
.unwrap_or_else(|| field.flat_name.clone())
});

Ok(FieldDefinition {
name: field
.alias
.clone()
.unwrap_or_else(|| field.flat_name.clone()),
field_type: FieldType::RepeatingGroup,
prompt,
default: None,
placeholder: None,
options: Vec::new(),
required: Some(!field.optional),
file_extension: None,
prefix_text: None,
page_size: None,
vim_mode: None,
custom_type: None,
min_date: None,
max_date: None,
week_start: None,
order,
when: None,
i18n: None,
group: field.group.clone(),
nickel_contract: field.contract.clone(),
nickel_path: Some(field.path.clone()),
nickel_doc: field.doc.clone(),
nickel_alias: field.alias.clone(),
fragment: Some(format!("fragments/{}.toml", fragment_name)),
min_items: if field.optional { Some(0) } else { Some(1) },
max_items: Some(10), // Default limit
default_items: Some(if field.optional { 0 } else { 1 }),
unique: None,
unique_key: None,
sensitive: None,
encryption_backend: None,
encryption_config: None,
})
}
}

/// Format a group title from group name
fn format_group_title(group: &str) -> String {
// Convert snake_case or kebab-case to Title Case
group
.split(['_', '-'])
.map(|word| {
let mut chars = word.chars();
match chars.next() {
None => String::new(),
Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(),
}
})
.collect::<Vec<_>>()
.join(" ")
}

/// Format a prompt from field name
fn format_prompt_from_path(flat_name: &str) -> String {
// Convert snake_case to Title Case
flat_name
.split('_')
.map(|word| {
let mut chars = word.chars();
match chars.next() {
None => String::new(),
Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(),
}
})
.collect::<Vec<_>>()
.join(" ")
}

#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;

#[test]
fn test_generate_simple_schema() {
let schema = NickelSchemaIR {
name: "test_schema".to_string(),
description: Some("A test schema".to_string()),
fields: vec![
NickelFieldIR {
path: vec!["name".to_string()],
flat_name: "name".to_string(),
alias: None,
nickel_type: NickelType::String,
doc: Some("User full name".to_string()),
default: Some(json!("Alice")),
optional: false,
contract: Some("String | std.string.NonEmpty".to_string()),
contract_call: None,
group: None,
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
},
NickelFieldIR {
path: vec!["age".to_string()],
flat_name: "age".to_string(),
alias: None,
nickel_type: NickelType::Number,
doc: Some("User age".to_string()),
default: None,
optional: true,
contract: None,
contract_call: None,
group: None,
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
},
],
};

let form = TomlGenerator::generate(&schema, false, false).unwrap();
assert_eq!(form.name, "test_schema");
assert_eq!(form.fields.len(), 2);

// Check first field
assert_eq!(form.fields[0].name, "name");
assert_eq!(form.fields[0].field_type, FieldType::Text);
assert_eq!(form.fields[0].required, Some(true));
assert_eq!(
form.fields[0].nickel_contract,
Some("String | std.string.NonEmpty".to_string())
);

// Check second field
assert_eq!(form.fields[1].name, "age");
assert_eq!(form.fields[1].field_type, FieldType::Custom);
assert_eq!(form.fields[1].custom_type, Some("f64".to_string()));
assert_eq!(form.fields[1].required, Some(false));
}

#[test]
fn test_generate_with_groups() {
let schema = NickelSchemaIR {
name: "grouped_schema".to_string(),
description: None,
fields: vec![
NickelFieldIR {
path: vec!["user_name".to_string()],
flat_name: "user_name".to_string(),
alias: None,
nickel_type: NickelType::String,
doc: Some("User name".to_string()),
default: None,
optional: false,
contract: None,
contract_call: None,
group: Some("user".to_string()),
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
},
NickelFieldIR {
path: vec!["settings_theme".to_string()],
flat_name: "settings_theme".to_string(),
alias: None,
nickel_type: NickelType::String,
doc: Some("Theme preference".to_string()),
default: Some(json!("dark")),
optional: false,
contract: None,
contract_call: None,
group: Some("settings".to_string()),
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
},
],
};

let form = TomlGenerator::generate(&schema, false, true).unwrap();

// Should have display items for groups
assert!(form.items.len() > 0);

// Check fields are grouped
assert_eq!(form.fields[0].group, Some("user".to_string()));
assert_eq!(form.fields[1].group, Some("settings".to_string()));
}

#[test]
fn test_nickel_type_to_field_type() {
let (field_type, custom_type) =
TomlGenerator::nickel_type_to_field_type(&NickelType::String).unwrap();
assert_eq!(field_type, FieldType::Text);
assert_eq!(custom_type, None);

let (field_type, custom_type) =
TomlGenerator::nickel_type_to_field_type(&NickelType::Number).unwrap();
assert_eq!(field_type, FieldType::Custom);
assert_eq!(custom_type, Some("f64".to_string()));

let (field_type, custom_type) =
TomlGenerator::nickel_type_to_field_type(&NickelType::Bool).unwrap();
assert_eq!(field_type, FieldType::Confirm);
assert_eq!(custom_type, None);
}

#[test]
fn test_format_group_title() {
assert_eq!(format_group_title("user"), "User");
assert_eq!(format_group_title("user_settings"), "User Settings");
assert_eq!(format_group_title("api-config"), "Api Config");
}

#[test]
fn test_format_prompt_from_path() {
assert_eq!(format_prompt_from_path("name"), "Name");
assert_eq!(format_prompt_from_path("user_name"), "User Name");
assert_eq!(format_prompt_from_path("first_name"), "First Name");
}

#[test]
fn test_extract_enum_options() {
let field = NickelFieldIR {
path: vec!["status".to_string()],
flat_name: "status".to_string(),
alias: None,
nickel_type: NickelType::Array(Box::new(NickelType::String)),
doc: Some("Status. Options: pending, active, completed".to_string()),
default: None,
optional: false,
contract: None,
contract_call: None,
group: None,
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
};

let options = TomlGenerator::extract_enum_options(&field);
assert_eq!(options.len(), 3);
assert_eq!(options[0].value, "pending");
assert_eq!(options[1].value, "active");
assert_eq!(options[2].value, "completed");
// Labels should be None for options extracted from doc strings
assert_eq!(options[0].label, None);
assert_eq!(options[1].label, None);
assert_eq!(options[2].label, None);
}

#[test]
fn test_default_value_conversion() {
let field = NickelFieldIR {
path: vec!["count".to_string()],
flat_name: "count".to_string(),
alias: None,
nickel_type: NickelType::Number,
doc: None,
default: Some(json!(42)),
optional: false,
contract: None,
contract_call: None,
group: None,
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
};

let schema = NickelSchemaIR {
name: "test".to_string(),
description: None,
fields: vec![field.clone()],
};

let form_field = TomlGenerator::field_ir_to_definition(&field, false, 0, &schema).unwrap();
assert_eq!(form_field.default, Some("42".to_string()));
}

#[test]
fn test_array_of_records_detection_and_fragment_generation() {
// Create a field with Array(Record(...)) type
let tracker_field = NickelFieldIR {
path: vec!["bind_address".to_string()],
flat_name: "bind_address".to_string(),
alias: None,
nickel_type: NickelType::String,
doc: Some("Bind Address".to_string()),
default: Some(json!("0.0.0.0:6969")),
optional: false,
contract: None,
contract_call: None,
group: None,
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
};

let udp_trackers_field = NickelFieldIR {
path: vec!["udp_trackers".to_string()],
flat_name: "udp_trackers".to_string(),
alias: Some("trackers".to_string()),
nickel_type: NickelType::Array(Box::new(NickelType::Record(vec![
tracker_field.clone()
]))),
doc: Some("UDP Tracker Listeners".to_string()),
default: None,
optional: true,
contract: None,
contract_call: None,
group: None,
fragment_marker: None,
is_array_of_records: true,
array_element_fields: Some(vec![tracker_field.clone()]),
};

let schema = NickelSchemaIR {
name: "tracker_config".to_string(),
description: Some("Torrust Tracker Configuration".to_string()),
fields: vec![udp_trackers_field.clone()],
};

// Test fragment generation
let forms = TomlGenerator::generate_with_fragments(&schema, true, false).unwrap();

// Should have main form + fragment form
assert!(forms.contains_key("main"));
assert!(
forms.contains_key("udp_trackers_item"),
"Should generate fragment for array element"
);

// Check main form has RepeatingGroup field
let main_form = forms.get("main").unwrap();
assert_eq!(main_form.fields.len(), 1);
assert_eq!(main_form.fields[0].field_type, FieldType::RepeatingGroup);
assert_eq!(main_form.fields[0].name, "trackers"); // Uses alias
assert_eq!(
main_form.fields[0].fragment,
Some("fragments/udp_trackers_item.toml".to_string())
);
assert_eq!(main_form.fields[0].min_items, Some(0)); // Optional
assert_eq!(main_form.fields[0].max_items, Some(10));
assert_eq!(main_form.fields[0].default_items, Some(0));

// Check fragment form has element fields
let fragment_form = forms.get("udp_trackers_item").unwrap();
assert_eq!(fragment_form.fields.len(), 1);
assert_eq!(fragment_form.fields[0].name, "bind_address");
assert_eq!(fragment_form.fields[0].field_type, FieldType::Text);
}

#[test]
fn test_nickel_type_array_of_records_maps_to_repeating_group() {
// Test that Array(Record(...)) maps to RepeatingGroup
let record_type = NickelType::Record(vec![]);
let array_of_records = NickelType::Array(Box::new(record_type));

let (field_type, custom_type) =
TomlGenerator::nickel_type_to_field_type(&array_of_records).unwrap();
assert_eq!(field_type, FieldType::RepeatingGroup);
assert_eq!(custom_type, None);

// Test that simple arrays still map to Editor
let simple_array = NickelType::Array(Box::new(NickelType::String));
let (field_type, custom_type) =
TomlGenerator::nickel_type_to_field_type(&simple_array).unwrap();
assert_eq!(field_type, FieldType::Editor);
assert_eq!(custom_type, Some("json".to_string()));
}
}
@ -3,7 +3,7 @@
//! Provides high-level functions for all prompt types with automatic
//! fallback to stdin when interactive mode isn't available.

use crate::error::{Error, Result};
use crate::error::{ErrorWrapper, Result};
use chrono::{NaiveDate, Weekday};
use inquire::{
Confirm, DateSelect, Editor as InquireEditor, MultiSelect, Password, PasswordDisplayMode,
@ -370,7 +370,7 @@ pub fn custom(prompt: &str, type_name: &str, default: Option<&str>) -> Result<St
if validator(&result) {
Ok(result)
} else {
Err(Error::validation_failed(format!(
Err(ErrorWrapper::validation_failed(format!(
"Invalid input for type {}",
type_name
)))
@ -424,7 +424,7 @@ fn stdin_confirm(prompt: &str, default: Option<bool>) -> Result<bool> {
"y" | "yes" | "true" => Ok(true),
"n" | "no" | "false" => Ok(false),
"" => Ok(default.unwrap_or(false)),
_ => Err(Error::validation_failed("Please answer yes or no")),
_ => Err(ErrorWrapper::validation_failed("Please answer yes or no")),
}
}

@ -528,16 +528,18 @@ fn open_editor_with_temp_file(
tempfile::Builder::new()
.suffix(&format!(".{}", ext))
.tempfile()
.map_err(Error::io)?
.map_err(ErrorWrapper::io)?
} else {
NamedTempFile::new().map_err(Error::io)?
NamedTempFile::new().map_err(ErrorWrapper::io)?
};

// Write prefix text to temp file
if let Some(text) = prefix_text {
let content = text.replace("\\n", "\n");
temp_file.write_all(content.as_bytes()).map_err(Error::io)?;
temp_file.flush().map_err(Error::io)?;
temp_file
.write_all(content.as_bytes())
.map_err(ErrorWrapper::io)?;
temp_file.flush().map_err(ErrorWrapper::io)?;
}

let path = temp_file.path().to_path_buf();
@ -557,17 +559,17 @@ fn open_editor_with_temp_file(
let status = Command::new(&editor)
.arg(&path)
.status()
.map_err(Error::io)?;
.map_err(ErrorWrapper::io)?;

if !status.success() {
return Err(Error::validation_failed(format!(
return Err(ErrorWrapper::validation_failed(format!(
"Editor '{}' exited with error code",
editor
)));
}

// Read the edited content back
let content = std::fs::read_to_string(&path).map_err(Error::io)?;
let content = std::fs::read_to_string(&path).map_err(ErrorWrapper::io)?;

Ok(content)
}
@ -633,11 +635,9 @@ where

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_stdlib_integration() {
// Just verify the types and signatures are correct
let _: Result<String> = Ok("test".to_string());
drop(Ok::<String, crate::error::ErrorWrapper>("test".to_string()));
}
}

@ -5,7 +5,7 @@ mod filters;

pub use context::TemplateContextBuilder;

use crate::error::{Error, Result};
use crate::error::{ErrorWrapper, Result};
use std::path::Path;
use tera::Tera;

@ -24,8 +24,9 @@ impl TemplateEngine {
let pattern = path.join("**/*.tera");
let pattern_str = pattern.to_str().unwrap_or("templates/**/*.tera");

Tera::new(pattern_str)
.map_err(|e| Error::template_failed(format!("Failed to initialize Tera: {}", e)))?
Tera::new(pattern_str).map_err(|e| {
ErrorWrapper::template_failed(format!("Failed to initialize Tera: {}", e))
})?
} else {
Tera::default()
};
@ -39,14 +40,14 @@ impl TemplateEngine {
/// Add a template string for inline template rendering
pub fn add_template(&mut self, name: &str, content: &str) -> Result<()> {
self.tera.add_raw_template(name, content).map_err(|e| {
Error::template_failed(format!("Failed to add template '{}': {}", name, e))
ErrorWrapper::template_failed(format!("Failed to add template '{}': {}", name, e))
})
}

/// Render a template by name with context
pub fn render(&self, template_name: &str, context: &tera::Context) -> Result<String> {
self.tera.render(template_name, context).map_err(|e| {
Error::template_failed(format!(
ErrorWrapper::template_failed(format!(
"Failed to render template '{}': {}",
template_name, e
))
@ -56,7 +57,7 @@ impl TemplateEngine {
/// Render a template string directly
pub fn render_str(&self, template: &str, context: &tera::Context) -> Result<String> {
Tera::one_off(template, context, false)
.map_err(|e| Error::template_failed(format!("Failed to render template: {}", e)))
.map_err(|e| ErrorWrapper::template_failed(format!("Failed to render template: {}", e)))
}

/// Check if a template exists

@ -10,12 +10,16 @@ mod encryption_tests {
use serde_json::json;
use std::collections::HashMap;
use typedialog_core::form_parser::{FieldDefinition, FieldType};
use typedialog_core::helpers::{EncryptionContext, format_results_secure, transform_results};
use typedialog_core::helpers::{format_results_secure, transform_results, EncryptionContext};

fn make_field(name: &str, sensitive: bool) -> FieldDefinition {
FieldDefinition {
name: name.to_string(),
field_type: if sensitive { FieldType::Password } else { FieldType::Text },
field_type: if sensitive {
FieldType::Password
} else {
FieldType::Text
},
prompt: format!("{}: ", name),
default: None,
placeholder: None,
@ -148,7 +152,10 @@ mod encryption_tests {
encryption_config: None,
};

assert!(field.is_sensitive(), "Password field should auto-detect as sensitive");
assert!(
field.is_sensitive(),
"Password field should auto-detect as sensitive"
);

let context = EncryptionContext::redact_only();
let transformed = transform_results(&results, &[field], &context, None).unwrap();
@ -198,7 +205,10 @@ mod encryption_tests {
encryption_config: None,
};

assert!(!field.is_sensitive(), "Explicit sensitive=false should override field type");
assert!(
!field.is_sensitive(),
"Explicit sensitive=false should override field type"
);

let context = EncryptionContext::redact_only();
let transformed = transform_results(&results, &[field], &context, None).unwrap();
@ -256,8 +266,10 @@ mod encryption_tests {
// Should fail with unknown backend error
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("Unknown encryption backend") ||
err.to_string().contains("unknown_backend"));
assert!(
err.to_string().contains("Unknown encryption backend")
|| err.to_string().contains("unknown_backend")
);
}

#[test]
@ -290,14 +302,14 @@ mod encryption_tests {

#[cfg(feature = "encryption")]
mod age_roundtrip_tests {
use age::secrecy::ExposeSecret;
use serde_json::json;
use std::collections::HashMap;
use typedialog_core::form_parser::{FieldDefinition, FieldType};
use typedialog_core::helpers::{EncryptionContext, transform_results};
use tempfile::TempDir;
use typedialog_core::encrypt::backend::age::AgeBackend;
use typedialog_core::encrypt::EncryptionBackend;
use age::secrecy::ExposeSecret;
use typedialog_core::form_parser::{FieldDefinition, FieldType};
use typedialog_core::helpers::{transform_results, EncryptionContext};

fn make_password_field(name: &str) -> FieldDefinition {
FieldDefinition {
@ -337,8 +349,7 @@ mod age_roundtrip_tests {
}

fn create_test_age_backend() -> std::result::Result<(AgeBackend, TempDir), String> {
let temp_dir = TempDir::new()
.map_err(|e| format!("Failed to create temp dir: {}", e))?;
let temp_dir = TempDir::new().map_err(|e| format!("Failed to create temp dir: {}", e))?;

// Generate test key pair
let secret_key = age::x25519::Identity::generate();
@ -400,10 +411,15 @@ mod age_roundtrip_tests {

// Encrypt twice with same plaintext
let cipher1 = backend.encrypt(plaintext).expect("First encryption failed");
let cipher2 = backend.encrypt(plaintext).expect("Second encryption failed");
let cipher2 = backend
.encrypt(plaintext)
.expect("Second encryption failed");

// Should produce different ciphertexts (different nonces)
assert_ne!(cipher1, cipher2, "Age should produce different ciphertexts for same plaintext");
assert_ne!(
cipher1, cipher2,
"Age should produce different ciphertexts for same plaintext"
);

// Both should decrypt to original
assert_eq!(
@ -435,7 +451,10 @@ mod age_roundtrip_tests {
let decrypted = backend.decrypt(&ciphertext).expect("Decryption failed");

// Verify roundtrip
assert_eq!(decrypted, plaintext, "Roundtrip encryption/decryption should preserve plaintext");
assert_eq!(
decrypted, plaintext,
"Roundtrip encryption/decryption should preserve plaintext"
);
}

#[test]
@ -444,7 +463,9 @@ mod age_roundtrip_tests {

let plaintext = "";

let ciphertext = backend.encrypt(plaintext).expect("Encryption of empty string failed");
let ciphertext = backend
.encrypt(plaintext)
.expect("Encryption of empty string failed");
let decrypted = backend.decrypt(&ciphertext).expect("Decryption failed");

assert_eq!(decrypted, plaintext, "Empty string roundtrip should work");
@ -456,10 +477,15 @@ mod age_roundtrip_tests {

let plaintext = "password_with_emoji_🔐_and_unicode_ñ";

let ciphertext = backend.encrypt(plaintext).expect("Encryption of unicode failed");
let ciphertext = backend
.encrypt(plaintext)
.expect("Encryption of unicode failed");
let decrypted = backend.decrypt(&ciphertext).expect("Decryption failed");

assert_eq!(decrypted, plaintext, "Unicode plaintext roundtrip should work");
assert_eq!(
decrypted, plaintext,
"Unicode plaintext roundtrip should work"
);
}

#[test]
@ -468,10 +494,15 @@ mod age_roundtrip_tests {

let plaintext = "x".repeat(10000);

let ciphertext = backend.encrypt(&plaintext).expect("Encryption of large value failed");
let ciphertext = backend
.encrypt(&plaintext)
.expect("Encryption of large value failed");
let decrypted = backend.decrypt(&ciphertext).expect("Decryption failed");

assert_eq!(decrypted, plaintext, "Large plaintext roundtrip should work");
assert_eq!(
decrypted, plaintext,
"Large plaintext roundtrip should work"
);
}

#[test]
@ -490,6 +521,9 @@ mod age_roundtrip_tests {
let (backend, _temp) = create_test_age_backend().unwrap();

let is_available = backend.is_available().expect("is_available check failed");
assert!(is_available, "Age backend should report itself as available");
assert!(
is_available,
"Age backend should report itself as available"
);
}
}

@ -148,7 +148,7 @@ fn test_nested_schema_with_flatten() {
let form = TomlGenerator::generate(&schema, false, true).expect("Form generation failed");

// Verify groups are created
assert!(form.items.len() > 0);
assert!(!form.items.is_empty());

// Verify fields have groups
let server_hostname = form
@ -194,7 +194,7 @@ fn test_array_field_serialization() {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
}],
};

@ -273,7 +273,7 @@ fn test_form_definition_from_schema_ir() {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
}],
};

@ -462,7 +462,7 @@ fn test_enum_options_extraction_from_doc() {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: None,
encryption_metadata: None,
};

let schema = NickelSchemaIR {
@ -963,7 +963,7 @@ fn test_torrust_tracker_schema_generation() {

// Export defaults to JSON (provides structure with real data including arrays)
let output = Command::new("nickel")
.args(&["export", "--format", "json", defaults_path])
.args(["export", "--format", "json", defaults_path])
.output()
.expect("Failed to execute nickel export");

@ -1097,17 +1097,18 @@ fn test_encryption_metadata_in_nickel_field() {
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: Some(
typedialog_core::nickel::EncryptionMetadata {
sensitive: true,
backend: Some("age".to_string()),
key: None,
}
),
encryption_metadata: Some(typedialog_core::nickel::EncryptionMetadata {
sensitive: true,
backend: Some("age".to_string()),
key: None,
}),
};

assert!(field.encryption_metadata.is_some());
assert_eq!(field.encryption_metadata.as_ref().unwrap().backend, Some("age".to_string()));
assert_eq!(
field.encryption_metadata.as_ref().unwrap().backend,
Some("age".to_string())
);
}

#[test]
@ -1141,19 +1142,19 @@ fn test_encryption_metadata_to_field_definition() {
doc: Some("User password - encrypted".to_string()),
default: None,
optional: false,
contract: Some("String | Sensitive Backend=\"age\" Key=\"~/.age/key.txt\"".to_string()),
contract: Some(
"String | Sensitive Backend=\"age\" Key=\"~/.age/key.txt\"".to_string(),
),
contract_call: None,
group: None,
fragment_marker: None,
is_array_of_records: false,
array_element_fields: None,
encryption_metadata: Some(
typedialog_core::nickel::EncryptionMetadata {
sensitive: true,
backend: Some("age".to_string()),
key: Some("~/.age/key.txt".to_string()),
}
),
encryption_metadata: Some(typedialog_core::nickel::EncryptionMetadata {
sensitive: true,
backend: Some("age".to_string()),
key: Some("~/.age/key.txt".to_string()),
}),
},
],
};
@ -1162,18 +1163,28 @@ fn test_encryption_metadata_to_field_definition() {
|
||||
let form = TomlGenerator::generate(&schema, false, false).expect("Form generation failed");
|
||||
|
||||
// Find password field
|
||||
let password_field = form.fields.iter().find(|f| f.name == "password").expect("Password field not found");
|
||||
let password_field = form
|
||||
.fields
|
||||
.iter()
|
||||
.find(|f| f.name == "password")
|
||||
.expect("Password field not found");
|
||||
|
||||
// Verify encryption metadata mapped to FieldDefinition
|
||||
assert_eq!(password_field.sensitive, Some(true));
|
||||
assert_eq!(password_field.encryption_backend, Some("age".to_string()));
|
||||
assert_eq!(password_field.encryption_config.as_ref().map(|c| c.get("key")).flatten(), Some(&"~/.age/key.txt".to_string()));
|
||||
assert_eq!(
|
||||
password_field
|
||||
.encryption_config
|
||||
.as_ref()
|
||||
.and_then(|c| c.get("key")),
|
||||
Some(&"~/.age/key.txt".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encryption_roundtrip_with_redaction() {
|
||||
use typedialog_core::helpers::{EncryptionContext, transform_results};
|
||||
use typedialog_core::form_parser::FieldType;
|
||||
use typedialog_core::helpers::{transform_results, EncryptionContext};
|
||||
|
||||
// Create form with sensitive fields
|
||||
let mut form_results = HashMap::new();
|
||||
@ -1292,19 +1303,31 @@ fn test_encryption_roundtrip_with_redaction() {
|
||||
.expect("Redaction failed");
|
||||
|
||||
// Verify redaction - extract string values idiomatically
let username = redacted.get("username").and_then(|v| v.as_str()).unwrap_or("");
let password = redacted.get("password").and_then(|v| v.as_str()).unwrap_or("");
let api_key = redacted.get("api_key").and_then(|v| v.as_str()).unwrap_or("");
let username = redacted
.get("username")
.and_then(|v| v.as_str())
.unwrap_or("");
let password = redacted
.get("password")
.and_then(|v| v.as_str())
.unwrap_or("");
let api_key = redacted
.get("api_key")
.and_then(|v| v.as_str())
.unwrap_or("");

assert_eq!(username, "alice", "Non-sensitive field should not be redacted");
assert_eq!(
username, "alice",
"Non-sensitive field should not be redacted"
);
assert_eq!(password, "[REDACTED]", "Sensitive field should be redacted");
assert_eq!(api_key, "[REDACTED]", "Sensitive field should be redacted");
}

#[test]
fn test_encryption_auto_detection_from_field_type() {
use typedialog_core::helpers::{EncryptionContext, transform_results};
use typedialog_core::form_parser::FieldType;
use typedialog_core::helpers::{transform_results, EncryptionContext};

let mut results = HashMap::new();
results.insert("password".to_string(), json!("secret_value"));
@ -1340,25 +1363,31 @@ fn test_encryption_auto_detection_from_field_type() {
default_items: None,
unique: None,
unique_key: None,
sensitive: None, // Not explicitly set
sensitive: None, // Not explicitly set
encryption_backend: None,
encryption_config: None,
};

assert!(field.is_sensitive(), "Password field should auto-detect as sensitive");
assert!(
field.is_sensitive(),
"Password field should auto-detect as sensitive"
);

let context = EncryptionContext::redact_only();
let transformed = transform_results(&results, &[field], &context, None)
.expect("Transform failed");
let transformed =
transform_results(&results, &[field], &context, None).expect("Transform failed");

let password_val = transformed.get("password").and_then(|v| v.as_str()).unwrap_or("");
let password_val = transformed
.get("password")
.and_then(|v| v.as_str())
.unwrap_or("");
assert_eq!(password_val, "[REDACTED]");
}

#[test]
fn test_sensitive_field_explicit_override() {
use typedialog_core::helpers::{EncryptionContext, transform_results};
use typedialog_core::form_parser::FieldType;
use typedialog_core::helpers::{transform_results, EncryptionContext};

let mut results = HashMap::new();
results.insert("password".to_string(), json!("visible_value"));
@ -1394,26 +1423,32 @@ fn test_sensitive_field_explicit_override() {
default_items: None,
unique: None,
unique_key: None,
sensitive: Some(false), // Explicitly override
sensitive: Some(false), // Explicitly override
encryption_backend: None,
encryption_config: None,
};

assert!(!field.is_sensitive(), "Explicit sensitive=false should override field type");
assert!(
!field.is_sensitive(),
"Explicit sensitive=false should override field type"
);

let context = EncryptionContext::redact_only();
let transformed = transform_results(&results, &[field], &context, None)
.expect("Transform failed");
let transformed =
transform_results(&results, &[field], &context, None).expect("Transform failed");

// Should NOT be redacted
let password_val = transformed.get("password").and_then(|v| v.as_str()).unwrap_or("");
let password_val = transformed
.get("password")
.and_then(|v| v.as_str())
.unwrap_or("");
assert_eq!(password_val, "visible_value");
}
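
Taken together, the two tests above pin down a simple precedence rule for sensitivity. A minimal sketch of that rule, assuming is_sensitive() reduces to roughly the following (a hypothetical illustration, not the actual typedialog-core source):

// Hypothetical reduction of the precedence the tests exercise: an explicit
// `sensitive` flag always wins; otherwise sensitivity falls back to
// field-type auto-detection (password fields default to sensitive).
fn is_sensitive(explicit: Option<bool>, is_password_field: bool) -> bool {
    explicit.unwrap_or(is_password_field)
}

fn main() {
    assert!(is_sensitive(None, true)); // auto-detected (first test above)
    assert!(!is_sensitive(Some(false), true)); // explicit override (second test)
}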

#[test]
fn test_mixed_sensitive_and_non_sensitive_fields() {
use typedialog_core::helpers::{EncryptionContext, transform_results};
use typedialog_core::form_parser::FieldType;
use typedialog_core::helpers::{transform_results, EncryptionContext};

let mut results = HashMap::new();
results.insert("username".to_string(), json!("alice"));
@ -1422,42 +1457,43 @@ fn test_mixed_sensitive_and_non_sensitive_fields() {
results.insert("api_token".to_string(), json!("token_xyz"));
results.insert("first_name".to_string(), json!("Alice"));

let make_basic_field = |name: &str, field_type: form_parser::FieldType, sensitive: Option<bool>| {
form_parser::FieldDefinition {
name: name.to_string(),
field_type,
prompt: format!("{}: ", name),
default: None,
placeholder: None,
options: vec![],
required: None,
file_extension: None,
prefix_text: None,
page_size: None,
vim_mode: None,
custom_type: None,
min_date: None,
max_date: None,
week_start: None,
order: 0,
when: None,
i18n: None,
group: None,
nickel_contract: None,
nickel_path: None,
nickel_doc: None,
nickel_alias: None,
fragment: None,
min_items: None,
max_items: None,
default_items: None,
unique: None,
unique_key: None,
sensitive,
encryption_backend: None,
encryption_config: None,
}
};
let make_basic_field =
|name: &str, field_type: form_parser::FieldType, sensitive: Option<bool>| {
form_parser::FieldDefinition {
name: name.to_string(),
field_type,
prompt: format!("{}: ", name),
default: None,
placeholder: None,
options: vec![],
required: None,
file_extension: None,
prefix_text: None,
page_size: None,
vim_mode: None,
custom_type: None,
min_date: None,
max_date: None,
week_start: None,
order: 0,
when: None,
i18n: None,
group: None,
nickel_contract: None,
nickel_path: None,
nickel_doc: None,
nickel_alias: None,
fragment: None,
min_items: None,
max_items: None,
default_items: None,
unique: None,
unique_key: None,
sensitive,
encryption_backend: None,
encryption_config: None,
}
};

let fields = vec![
make_basic_field("username", FieldType::Text, Some(false)),
@ -1468,8 +1504,7 @@ fn test_mixed_sensitive_and_non_sensitive_fields() {
];

let context = EncryptionContext::redact_only();
let redacted = transform_results(&results, &fields, &context, None)
.expect("Transform failed");
let redacted = transform_results(&results, &fields, &context, None).expect("Transform failed");

// Extract values with idiomatic error handling
let get_str = |key: &str| redacted.get(key).and_then(|v| v.as_str()).unwrap_or("");

322
crates/typedialog-core/tests/proptest_validation.rs
Normal file
@ -0,0 +1,322 @@
use proptest::prelude::*;
use serde_json::json;
use typedialog_core::nickel::ContractValidator;

proptest! {
#[test]
fn test_email_validation_never_panics(input in "\\PC*") {
let value = json!(input);
drop(ContractValidator::validate_email(&value));
}

#[test]
fn test_valid_emails_accepted(
local in "[a-zA-Z0-9._-]{1,20}",
domain_parts in prop::collection::vec("[a-zA-Z0-9-]{1,10}", 2..5),
tld in "[a-z]{2,6}"
) {
let domain = format!("{}.{}", domain_parts.join("."), tld);
let email = format!("{}@{}", local, domain);
let value = json!(email);

let result = ContractValidator::validate_email(&value);
prop_assert!(
result.is_ok(),
"Valid email format '{}' was rejected: {:?}",
email,
result.err()
);
}

#[test]
fn test_invalid_emails_rejected_no_at(input in "[a-zA-Z0-9._-]+") {
let value = json!(input);
let result = ContractValidator::validate_email(&value);
prop_assert!(result.is_err(), "Email without @ symbol should be rejected");
}

#[test]
fn test_invalid_emails_rejected_multiple_at(
local in "[a-zA-Z0-9._-]{1,10}",
domain in "[a-zA-Z0-9.-]{1,20}"
) {
let email = format!("{}@{}@extra", local, domain);
let value = json!(email);
let result = ContractValidator::validate_email(&value);
prop_assert!(result.is_err(), "Email with multiple @ symbols should be rejected");
}

#[test]
fn test_invalid_emails_rejected_no_domain_dot(
local in "[a-zA-Z0-9._-]{1,10}",
domain in "[a-zA-Z0-9-]{1,10}"
) {
let email = format!("{}@{}", local, domain);
let value = json!(email);
let result = ContractValidator::validate_email(&value);
prop_assert!(
result.is_err(),
"Email without dot in domain '{}' should be rejected",
email
);
}

#[test]
fn test_url_validation_never_panics(input in "\\PC*") {
let value = json!(input);
drop(ContractValidator::validate_url(&value));
}

#[test]
fn test_valid_urls_accepted(
scheme in prop::sample::select(&["http://", "https://", "ftp://", "ftps://"]),
host in "[a-zA-Z0-9.-]{1,30}",
path in prop::option::of("[a-zA-Z0-9/_-]{0,50}")
) {
let url = if let Some(p) = path {
format!("{}{}/{}", scheme, host, p)
} else {
format!("{}{}", scheme, host)
};

let value = json!(url);
let result = ContractValidator::validate_url(&value);
prop_assert!(
result.is_ok(),
"Valid URL format '{}' was rejected: {:?}",
url,
result.err()
);
}

#[test]
fn test_invalid_urls_rejected_no_scheme(
host in "[a-zA-Z0-9.-]{1,20}",
path in "[a-zA-Z0-9/_-]{0,20}"
) {
let url = format!("{}/{}", host, path);
let value = json!(url);
let result = ContractValidator::validate_url(&value);
prop_assert!(result.is_err(), "URL without scheme should be rejected");
}

#[test]
fn test_invalid_urls_rejected_invalid_scheme(
scheme in "[a-z]{3,8}",
host in "[a-zA-Z0-9.-]{1,20}"
) {
prop_assume!(!["http", "https", "ftp", "ftps"].contains(&scheme.as_str()));
let url = format!("{}://{}", scheme, host);
let value = json!(url);
let result = ContractValidator::validate_url(&value);
prop_assert!(
result.is_err(),
"URL with invalid scheme '{}' should be rejected",
scheme
);
}

#[test]
fn test_invalid_urls_rejected_empty_host(
scheme in prop::sample::select(&["http://", "https://", "ftp://", "ftps://"])
) {
let url = scheme.to_string();
let value = json!(url);
let result = ContractValidator::validate_url(&value);
prop_assert!(result.is_err(), "URL with empty host should be rejected");
}

#[test]
fn test_email_validation_type_safety(value in prop_oneof![
Just(json!(42)),
Just(json!(true)),
Just(json!(null)),
Just(json!([])),
Just(json!({})),
]) {
let result = ContractValidator::validate_email(&value);
prop_assert!(
result.is_err(),
"Email validation should reject non-string types"
);
}

#[test]
fn test_url_validation_type_safety(value in prop_oneof![
Just(json!(42)),
Just(json!(true)),
Just(json!(null)),
Just(json!([])),
Just(json!({})),
]) {
let result = ContractValidator::validate_url(&value);
prop_assert!(
result.is_err(),
"URL validation should reject non-string types"
);
}

#[test]
fn test_empty_string_email_rejected(whitespace in "[ \\t\\n\\r]*") {
let value = json!(whitespace);
let result = ContractValidator::validate_email(&value);
prop_assert!(
result.is_err(),
"Empty or whitespace-only email should be rejected"
);
}

#[test]
fn test_empty_string_url_rejected(whitespace in "[ \\t\\n\\r]*") {
let value = json!(whitespace);
let result = ContractValidator::validate_url(&value);
prop_assert!(
result.is_err(),
"Empty or whitespace-only URL should be rejected"
);
}

#[test]
fn test_email_domain_boundary_dots_rejected(
local in "[a-zA-Z0-9._-]{1,10}",
domain in "[a-zA-Z0-9-]{1,10}",
tld in "[a-z]{2,6}"
) {
let email_start_dot = format!("{}@.{}.{}", local, domain, tld);
let email_end_dot = format!("{}@{}.{}.", local, domain, tld);

let value_start = json!(email_start_dot);
let result_start = ContractValidator::validate_email(&value_start);
prop_assert!(
result_start.is_err(),
"Email with domain starting with dot should be rejected"
);

let value_end = json!(email_end_dot);
let result_end = ContractValidator::validate_email(&value_end);
prop_assert!(
result_end.is_err(),
"Email with domain ending with dot should be rejected"
);
}

#[test]
fn test_contract_validate_integration_email(
local in "[a-zA-Z0-9._-]{1,15}",
domain in "[a-zA-Z0-9.-]{3,20}\\.[a-z]{2,6}"
) {
let email = format!("{}@{}", local, domain);
let value = json!(email);

let result = ContractValidator::validate(&value, "String | std.string.Email");
prop_assert!(
result.is_ok() || result.is_err(),
"Contract validation should never panic for any email input"
);
}

#[test]
fn test_contract_validate_integration_url(
scheme in prop::sample::select(&["http", "https", "ftp", "ftps", "gopher", "file"]),
host in "[a-zA-Z0-9.-]{1,30}"
) {
let url = format!("{}://{}", scheme, host);
let value = json!(url);

let result = ContractValidator::validate(&value, "String | std.string.Url");
prop_assert!(
result.is_ok() || result.is_err(),
"Contract validation should never panic for any URL input"
);
}
}

#[test]
fn test_known_valid_emails() {
let valid_emails = vec![
"user@example.com",
"test.user@example.com",
"user+tag@example.com",
"user_name@example.co.uk",
"a@b.c",
"very.long.email.address@subdomain.example.org",
];

for email in valid_emails {
let value = json!(email);
let result = ContractValidator::validate_email(&value);
assert!(
result.is_ok(),
"Valid email '{}' was rejected: {:?}",
email,
result.err()
);
}
}

#[test]
fn test_known_invalid_emails() {
let invalid_emails = vec![
"",
"@",
"@@",
"user@",
"@example.com",
"user",
"user@@example.com",
"user@example",
"user@.example.com",
"user@example.com.",
];

for email in invalid_emails {
let value = json!(email);
let result = ContractValidator::validate_email(&value);
assert!(result.is_err(), "Invalid email '{}' was accepted", email);
}
}

#[test]
fn test_known_valid_urls() {
let valid_urls = vec![
"http://example.com",
"https://example.com",
"https://example.com/path",
"https://subdomain.example.com/path/to/resource",
"ftp://ftp.example.com",
"ftps://secure.example.com",
"http://localhost",
"https://192.168.1.1",
];

for url in valid_urls {
let value = json!(url);
let result = ContractValidator::validate_url(&value);
assert!(
result.is_ok(),
"Valid URL '{}' was rejected: {:?}",
url,
result.err()
);
}
}

#[test]
fn test_known_invalid_urls() {
let invalid_urls = vec![
"",
"http://",
"https://",
"example.com",
"ftp:/example.com",
"gopher://example.com",
"file://example.com",
"htp://example.com",
];

for url in invalid_urls {
let value = json!(url);
let result = ContractValidator::validate_url(&value);
assert!(result.is_err(), "Invalid URL '{}' was accepted", url);
}
}
@ -11,6 +11,11 @@ description = "TypeDialog TUI tool for interactive forms using ratatui"
name = "typedialog-tui"
path = "src/main.rs"

[package.metadata.binstall]
pkg-url = "{ repo }/releases/download/v{ version }/typedialog-{ target }.tar.gz"
bin-dir = "bin/{ bin }"
pkg-fmt = "tgz"
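
The binstall metadata above tells cargo-binstall where to fetch prebuilt release artifacts; the same block is repeated in the web and cli crate manifests below. A rough sketch of the placeholder expansion it performs, with a hypothetical repository URL and version (not taken from this repo):

// Hypothetical illustration of binstall's `{ repo }` / `{ version }` / `{ target }`
// substitution into the pkg-url template above.
fn expand_pkg_url(repo: &str, version: &str, target: &str) -> String {
    format!("{repo}/releases/download/v{version}/typedialog-{target}.tar.gz")
}

fn main() {
    let url = expand_pkg_url(
        "https://github.com/example/typedialog", // placeholder repo URL
        "0.1.0",                                 // placeholder version
        "x86_64-unknown-linux-gnu",
    );
    assert!(url.ends_with("typedialog-x86_64-unknown-linux-gnu.tar.gz"));
}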

[dependencies]
typedialog-core = { path = "../typedialog-core", features = ["tui", "i18n", "encryption"] }
clap = { workspace = true }

@ -1,3 +1,5 @@
#![allow(clippy::result_large_err)]

//! typedialog-tui - Terminal UI tool for interactive forms
//!
//! A terminal UI (TUI) tool for creating interactive forms with enhanced visual presentation.
@ -9,7 +11,7 @@ use std::fs;
use std::path::PathBuf;
use typedialog_core::backends::{BackendFactory, BackendType};
use typedialog_core::cli_common;
use typedialog_core::config::TypeDialogConfig;
use typedialog_core::config::{load_backend_config, TypeDialogConfig};
use typedialog_core::helpers;
use typedialog_core::i18n::{I18nBundle, LocaleLoader, LocaleResolver};
use typedialog_core::nickel::{
@ -30,6 +32,13 @@ struct Args {
#[command(subcommand)]
command: Option<Commands>,

/// TUI backend configuration file (TOML)
///
/// If provided, uses this file exclusively.
/// If not provided, searches: ~/.config/typedialog/tui/{TYPEDIALOG_ENV}.toml → ~/.config/typedialog/tui/config.toml → defaults
#[arg(global = true, short = 'c', long, value_name = "FILE")]
backend_config: Option<PathBuf>,

/// Path to TOML form configuration file (for default form command)
config: Option<PathBuf>,

@ -233,6 +242,15 @@ fn extract_nickel_defaults(
async fn main() -> Result<()> {
let args = Args::parse();

// Load configuration with CLI override
let config_path = args.backend_config.as_deref();
let _config =
load_backend_config::<TypeDialogConfig>("tui", config_path, TypeDialogConfig::default())?;

if let Some(path) = config_path {
eprintln!("📋 Using config: {}", path.display());
}

match args.command {
Some(Commands::Form { config, defaults }) => {
execute_form(config, defaults, &args.format, &args.out, &args.locale).await?;

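That doc comment fixes the config search order for all three binaries (tui here, web and cli below). A minimal sketch of the fallback chain, assuming only what the comment documents; the real load_backend_config in typedialog-core may differ in its details:

use std::path::PathBuf;

// Hypothetical illustration: resolve the backend config path for a backend
// ("tui", "web", "cli") following the documented precedence: explicit
// -c/--backend-config flag first, then the TYPEDIALOG_ENV-specific file,
// then config.toml, and finally None (built-in defaults). The `dirs` crate
// is already a workspace dependency of typedialog-core.
fn resolve_config_path(backend: &str, cli_override: Option<PathBuf>) -> Option<PathBuf> {
    if let Some(explicit) = cli_override {
        return Some(explicit); // used exclusively when provided
    }
    let base = dirs::config_dir()?.join("typedialog").join(backend);
    if let Ok(env_name) = std::env::var("TYPEDIALOG_ENV") {
        let env_file = base.join(format!("{env_name}.toml"));
        if env_file.exists() {
            return Some(env_file);
        }
    }
    let default_file = base.join("config.toml");
    default_file.exists().then_some(default_file)
}
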
@ -11,6 +11,11 @@ description = "TypeDialog Web server for interactive forms using axum"
name = "typedialog-web"
path = "src/main.rs"

[package.metadata.binstall]
pkg-url = "{ repo }/releases/download/v{ version }/typedialog-{ target }.tar.gz"
bin-dir = "bin/{ bin }"
pkg-fmt = "tgz"

[dependencies]
typedialog-core = { path = "../typedialog-core", features = ["web", "i18n", "encryption"] }
clap = { workspace = true }

@ -1,3 +1,5 @@
#![allow(clippy::result_large_err)]

//! typedialog-web - Web server for interactive forms
//!
//! A web server tool for creating interactive forms accessible via HTTP.
@ -8,7 +10,7 @@ use std::fs;
use std::path::PathBuf;
use typedialog_core::backends::{BackendFactory, BackendType};
use typedialog_core::cli_common;
use typedialog_core::config::TypeDialogConfig;
use typedialog_core::config::{load_backend_config, TypeDialogConfig};
use typedialog_core::i18n::{I18nBundle, LocaleLoader, LocaleResolver};
use typedialog_core::{form_parser, helpers, Error, Result};
use unic_langid::LanguageIdentifier;
@ -24,6 +26,13 @@ struct Args {
#[command(subcommand)]
command: Option<Commands>,

/// Web backend configuration file (TOML)
///
/// If provided, uses this file exclusively.
/// If not provided, searches: ~/.config/typedialog/web/{TYPEDIALOG_ENV}.toml → ~/.config/typedialog/web/config.toml → defaults
#[arg(global = true, short = 'c', long, value_name = "FILE")]
backend_config: Option<PathBuf>,

/// Path to TOML form configuration file
config: Option<PathBuf>,

@ -168,6 +177,15 @@ fn extract_nickel_defaults(
async fn main() -> Result<()> {
let args = Args::parse();

// Load configuration with CLI override
let config_path = args.backend_config.as_deref();
let _config =
load_backend_config::<TypeDialogConfig>("web", config_path, TypeDialogConfig::default())?;

if let Some(path) = config_path {
eprintln!("📋 Using config: {}", path.display());
}

match args.command {
Some(Commands::Form { config, defaults }) => {
execute_form(

@ -11,6 +11,11 @@ description = "TypeDialog CLI tool for interactive forms and prompts"
name = "typedialog"
path = "src/main.rs"

[package.metadata.binstall]
pkg-url = "{ repo }/releases/download/v{ version }/typedialog-{ target }.tar.gz"
bin-dir = "bin/{ bin }"
pkg-fmt = "tgz"

[dependencies]
typedialog-core = { path = "../typedialog-core", features = ["cli", "i18n", "encryption"] }
clap = { workspace = true }

@ -1,3 +1,6 @@
#![allow(clippy::too_many_arguments)]
#![allow(clippy::result_large_err)]

//! typedialog - Interactive forms and prompts CLI tool
//!
//! A powerful CLI tool for creating interactive forms and prompts using multiple backends.
@ -10,7 +13,7 @@ use std::fs;
use std::path::PathBuf;
use typedialog_core::backends::BackendFactory;
use typedialog_core::cli_common;
use typedialog_core::config::TypeDialogConfig;
use typedialog_core::config::{load_backend_config, TypeDialogConfig};
use typedialog_core::helpers;
use typedialog_core::i18n::{I18nBundle, LocaleLoader, LocaleResolver};
use typedialog_core::nickel::{
@ -31,6 +34,13 @@ struct Cli {
#[command(subcommand)]
command: Commands,

/// CLI backend configuration file (TOML)
///
/// If provided, uses this file exclusively.
/// If not provided, searches: ~/.config/typedialog/cli/{TYPEDIALOG_ENV}.toml → ~/.config/typedialog/cli/config.toml → defaults
#[arg(global = true, short = 'c', long, value_name = "FILE")]
config: Option<PathBuf>,

/// Output format: json, yaml, toml, or text
#[arg(global = true, short, long, default_value = "text", help = cli_common::FORMAT_FLAG_HELP)]
format: String,
@ -300,6 +310,15 @@ enum Commands {
async fn main() -> Result<()> {
let cli = Cli::parse();

// Load configuration with CLI override
let config_path = cli.config.as_deref();
let _config =
load_backend_config::<TypeDialogConfig>("cli", config_path, TypeDialogConfig::default())?;

if let Some(path) = config_path {
eprintln!("📋 Using config: {}", path.display());
}

match cli.command {
Commands::Text {
prompt,
@ -727,7 +746,14 @@ async fn execute_form(
};

let config = TypeDialogConfig::default();
print_results(&results, format, output_file, &form_fields, &encryption_context, config.encryption.as_ref())?;
print_results(
&results,
format,
output_file,
&form_fields,
&encryption_context,
config.encryption.as_ref(),
)?;
}

Ok(())
@ -766,7 +792,8 @@ fn print_results(
encryption_context: &helpers::EncryptionContext,
global_config: Option<&typedialog_core::config::EncryptionDefaults>,
) -> Result<()> {
let output = helpers::format_results_secure(results, fields, format, encryption_context, global_config)?;
let output =
helpers::format_results_secure(results, fields, format, encryption_context, global_config)?;

if let Some(path) = output_file {
fs::write(path, &output).map_err(Error::io)?;

31
examples/06-i18n/en-US.toml
Normal file
@ -0,0 +1,31 @@
# English translations (alternative TOML format)

[forms.registration]
title = "User Registration"
description = "Create a new user account"
username-label = "Username"
username-prompt = "Please enter a username"
username-placeholder = "user123"
email-label = "Email Address"
email-prompt = "Please enter your email address"
email-placeholder = "user@example.com"

[forms.registration.roles]
admin = "Administrator"
user = "Regular User"
guest = "Guest"
developer = "Developer"

[forms.employee-onboarding]
title = "Employee Onboarding"
description = "Complete your onboarding process"
welcome = "Welcome to the team!"
full-name-prompt = "What is your full name?"
department-prompt = "Which department are you joining?"
start-date-prompt = "What is your start date?"

[forms.feedback]
title = "Feedback Form"
overall-satisfaction-prompt = "How satisfied are you with our service?"
improvement-prompt = "What could we improve?"
contact-prompt = "Can we contact you with follow-up questions?"
37
examples/06-i18n/en-US/forms.ftl
Normal file
@ -0,0 +1,37 @@
# English translations for common form fields

## Registration form
registration-title = User Registration
registration-description = Create a new user account
registration-username-label = Username
registration-username-prompt = Please enter a username
registration-username-placeholder = user123
registration-email-label = Email Address
registration-email-prompt = Please enter your email address
registration-email-placeholder = user@example.com
registration-password-label = Password
registration-password-prompt = Please enter a password
registration-password-placeholder = ••••••••
registration-confirm-label = I agree to the terms and conditions
registration-confirm-prompt = Do you agree to the terms and conditions?

## Role selection
role-prompt = Please select your role
role-admin = Administrator
role-user = Regular User
role-guest = Guest
role-developer = Developer

## Common actions
action-submit = Submit
action-cancel = Cancel
action-next = Next
action-previous = Previous
action-confirm = Confirm
action-decline = Decline

## Common validation messages
error-required = This field is required
error-invalid-email = Please enter a valid email address
error-password-too-short = Password must be at least 8 characters
error-passwords-mismatch = Passwords do not match
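
These Fluent messages are consumed through the optional fluent and fluent-bundle dependencies. A minimal sketch of resolving one of them with fluent-bundle directly, which is not necessarily how typedialog-core's I18nBundle wraps it:

use fluent_bundle::{FluentBundle, FluentResource};
use unic_langid::LanguageIdentifier;

fn main() {
    // One message from en-US/forms.ftl above.
    let ftl = "registration-title = User Registration".to_string();
    let resource = FluentResource::try_new(ftl).expect("FTL should parse");

    let lang: LanguageIdentifier = "en-US".parse().expect("valid language tag");
    let mut bundle = FluentBundle::new(vec![lang]);
    bundle.add_resource(resource).expect("no duplicate message ids");

    let msg = bundle.get_message("registration-title").expect("message defined");
    let pattern = msg.value().expect("message has a value");
    let mut errors = vec![];
    let title = bundle.format_pattern(pattern, None, &mut errors);
    assert_eq!(title, "User Registration");
}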
31
examples/06-i18n/es-ES.toml
Normal file
@ -0,0 +1,31 @@
# Traducciones al español (formato TOML alternativo)

[forms.registration]
title = "Registro de Usuario"
description = "Crear una nueva cuenta de usuario"
username-label = "Nombre de usuario"
username-prompt = "Por favor, ingrese su nombre de usuario"
username-placeholder = "usuario123"
email-label = "Correo electrónico"
email-prompt = "Por favor, ingrese su correo electrónico"
email-placeholder = "usuario@ejemplo.com"

[forms.registration.roles]
admin = "Administrador"
user = "Usuario Regular"
guest = "Invitado"
developer = "Desarrollador"

[forms.employee-onboarding]
title = "Incorporación de Empleado"
description = "Complete su proceso de incorporación"
welcome = "¡Bienvenido al equipo!"
full-name-prompt = "¿Cuál es su nombre completo?"
department-prompt = "¿A cuál departamento se está uniendo?"
start-date-prompt = "¿Cuál es su fecha de inicio?"

[forms.feedback]
title = "Formulario de Retroalimentación"
overall-satisfaction-prompt = "¿Cuán satisfecho está con nuestro servicio?"
improvement-prompt = "¿Qué podríamos mejorar?"
contact-prompt = "¿Podemos contactarlo con preguntas de seguimiento?"
37
examples/06-i18n/es-ES/forms.ftl
Normal file
@ -0,0 +1,37 @@
# Traducciones al español para formularios comunes

## Formulario de registro
registration-title = Registro de Usuario
registration-description = Crear una nueva cuenta de usuario
registration-username-label = Nombre de usuario
registration-username-prompt = Por favor, ingrese su nombre de usuario
registration-username-placeholder = usuario123
registration-email-label = Correo electrónico
registration-email-prompt = Por favor, ingrese su correo electrónico
registration-email-placeholder = usuario@ejemplo.com
registration-password-label = Contraseña
registration-password-prompt = Por favor, ingrese su contraseña
registration-password-placeholder = ••••••••
registration-confirm-label = Acepto los términos y condiciones
registration-confirm-prompt = ¿Acepta los términos y condiciones?

## Selección de rol
role-prompt = Por favor, seleccione su rol
role-admin = Administrador
role-user = Usuario Regular
role-guest = Invitado
role-developer = Desarrollador

## Acciones comunes
action-submit = Enviar
action-cancel = Cancelar
action-next = Siguiente
action-previous = Anterior
action-confirm = Confirmar
action-decline = Rechazar

## Mensajes de validación comunes
error-required = Este campo es requerido
error-invalid-email = Por favor, ingrese una dirección de correo válida
error-password-too-short = La contraseña debe tener al menos 8 caracteres
error-passwords-mismatch = Las contraseñas no coinciden
@ -282,8 +282,8 @@ key = "~/.age/key.txt"

## Documentation References

- **Setup Guide**: See `docs/ENCRYPTION-SERVICES-SETUP.md`
- **Quick Start**: See `docs/ENCRYPTION-QUICK-START.md`
- **Setup Guide**: See `docs/encryption-services-setup.md`
- **Quick Start**: See `docs/encryption-quick-start.md`
- **Implementation Status**: See `docs/ENCRYPTION-IMPLEMENTATION-STATUS.md`

---

@ -434,6 +434,6 @@ $ typedialog form examples/08-encryption/simple-login.toml \
## See Also

- [Examples Directory](./README.md)
- [Encryption Architecture Guide](../../docs/ENCRYPTION-UNIFIED-ARCHITECTURE.md)
- [Encryption Architecture Guide](../../docs/encryption-unified-architecture.md)
- [SOPS GitHub](https://github.com/getsops/sops)
- [Age GitHub](https://github.com/FiloSottile/age)

@ -535,4 +535,4 @@ After verifying SOPS works:
- [SOPS GitHub](https://github.com/getsops/sops)
- [Age GitHub](https://github.com/FiloSottile/age)
- [typedialog Encryption Examples](./README.md)
- [Encryption Architecture](../../docs/ENCRYPTION-UNIFIED-ARCHITECTURE.md)
- [Encryption Architecture](../../docs/encryption-unified-architecture.md)