103 lines
2.4 KiB
Rust
103 lines
2.4 KiB
Rust
|
|
use async_trait::async_trait;
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
|
||
|
|
use crate::{error::EmbeddingError, traits::Embedding};
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
|
|
pub struct SearchResult {
|
||
|
|
pub id: String,
|
||
|
|
pub score: f32,
|
||
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||
|
|
pub embedding: Option<Embedding>,
|
||
|
|
pub metadata: serde_json::Value,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Default)]
|
||
|
|
pub struct SearchFilter {
|
||
|
|
pub metadata: Option<serde_json::Value>,
|
||
|
|
pub min_score: Option<f32>,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone)]
|
||
|
|
pub struct VectorStoreConfig {
|
||
|
|
pub dimensions: usize,
|
||
|
|
pub metric: DistanceMetric,
|
||
|
|
pub options: serde_json::Value,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl VectorStoreConfig {
|
||
|
|
pub fn new(dimensions: usize) -> Self {
|
||
|
|
Self {
|
||
|
|
dimensions,
|
||
|
|
metric: DistanceMetric::default(),
|
||
|
|
options: serde_json::json!({}),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
|
||
|
|
self.metric = metric;
|
||
|
|
self
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn with_options(mut self, options: serde_json::Value) -> Self {
|
||
|
|
self.options = options;
|
||
|
|
self
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Distance metric a vector store can be configured with.
///
/// [`DistanceMetric::Cosine`] is the default. The enum is `Copy` and also
/// implements `PartialEq`, `Eq` and `Hash`, so values can be compared
/// directly and used as map or set keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Cosine metric (the default variant).
    #[default]
    Cosine,
    /// Euclidean metric.
    Euclidean,
    /// Dot-product metric.
    DotProduct,
}
|
||
|
|
|
||
|
|
#[async_trait]
|
||
|
|
pub trait VectorStore: Send + Sync {
|
||
|
|
fn name(&self) -> &str;
|
||
|
|
fn dimensions(&self) -> usize;
|
||
|
|
|
||
|
|
async fn upsert(
|
||
|
|
&self,
|
||
|
|
id: &str,
|
||
|
|
embedding: &Embedding,
|
||
|
|
metadata: serde_json::Value,
|
||
|
|
) -> Result<(), EmbeddingError>;
|
||
|
|
|
||
|
|
async fn upsert_batch(
|
||
|
|
&self,
|
||
|
|
items: Vec<(String, Embedding, serde_json::Value)>,
|
||
|
|
) -> Result<(), EmbeddingError> {
|
||
|
|
for (id, embedding, metadata) in items {
|
||
|
|
self.upsert(&id, &embedding, metadata).await?;
|
||
|
|
}
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn search(
|
||
|
|
&self,
|
||
|
|
embedding: &Embedding,
|
||
|
|
limit: usize,
|
||
|
|
filter: Option<SearchFilter>,
|
||
|
|
) -> Result<Vec<SearchResult>, EmbeddingError>;
|
||
|
|
|
||
|
|
async fn get(&self, id: &str) -> Result<Option<SearchResult>, EmbeddingError>;
|
||
|
|
|
||
|
|
async fn delete(&self, id: &str) -> Result<bool, EmbeddingError>;
|
||
|
|
|
||
|
|
async fn delete_batch(&self, ids: &[&str]) -> Result<usize, EmbeddingError> {
|
||
|
|
let mut count = 0;
|
||
|
|
for id in ids {
|
||
|
|
if self.delete(id).await? {
|
||
|
|
count += 1;
|
||
|
|
}
|
||
|
|
}
|
||
|
|
Ok(count)
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn flush(&self) -> Result<(), EmbeddingError>;
|
||
|
|
|
||
|
|
async fn count(&self) -> Result<usize, EmbeddingError>;
|
||
|
|
}
|