103 lines
2.4 KiB
Rust
Raw Normal View History

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::{error::EmbeddingError, traits::Embedding};
/// One entry returned from a [`VectorStore`] lookup (`search` or `get`).
///
/// Field order matters for serialized output, so keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Identifier the entry is stored under.
    pub id: String,

    /// Score assigned by the store. Its scale and direction (similarity vs.
    /// distance) depend on the backend and its metric — NOTE(review): confirm
    /// per implementation.
    pub score: f32,

    /// The stored vector, when the backend chose to return it.
    /// Left out of the serialized form entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embedding: Option<Embedding>,

    /// JSON metadata associated with the entry (presumably the value given to
    /// `upsert` — verify per implementation).
    pub metadata: serde_json::Value,
}
/// Optional constraints passed to [`VectorStore::search`].
///
/// `SearchFilter::default()` leaves both constraints unset.
#[derive(Debug, Clone, Default)]
pub struct SearchFilter {
    /// Metadata constraint as JSON. How it is matched against stored metadata
    /// is up to the implementation — NOTE(review): confirm per backend.
    pub metadata: Option<serde_json::Value>,

    /// Presumably a lower bound on `SearchResult::score`; how (or whether) it
    /// is enforced is implementation-defined.
    pub min_score: Option<f32>,
}
/// Construction parameters for a vector store.
///
/// Build one with `VectorStoreConfig::new` and the `with_*` methods.
#[derive(Debug, Clone)]
pub struct VectorStoreConfig {
    /// Length of every vector the store will hold.
    pub dimensions: usize,

    /// Metric used to compare vectors.
    pub metric: DistanceMetric,

    /// Free-form, backend-specific options as JSON.
    pub options: serde_json::Value,
}
impl VectorStoreConfig {
    /// Creates a config for vectors of length `dimensions`.
    ///
    /// The metric starts as `DistanceMetric::default()`, and `options` starts
    /// as an empty JSON object (`{}`).
    #[must_use]
    pub fn new(dimensions: usize) -> Self {
        Self {
            dimensions,
            metric: DistanceMetric::default(),
            options: serde_json::json!({}),
        }
    }

    /// Returns this config with `metric` replacing the current metric.
    ///
    /// Consumes `self`, so the returned value must be used. Dropping it would
    /// silently discard the whole configuration, which is why it is `#[must_use]`.
    #[must_use]
    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
        self.metric = metric;
        self
    }

    /// Returns this config with `options` replacing the current options.
    ///
    /// The previous value is replaced wholesale, not merged. Consumes `self`,
    /// so the returned value must be used.
    #[must_use]
    pub fn with_options(mut self, options: serde_json::Value) -> Self {
        self.options = options;
        self
    }
}
/// Metric a vector store uses to compare embeddings.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare a configured metric
/// against a variant (e.g. `config.metric == DistanceMetric::Cosine`) or use it
/// as a map/set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Cosine similarity. This is the default metric.
    #[default]
    Cosine,
    /// Euclidean (L2) distance.
    Euclidean,
    /// Dot (inner) product.
    DotProduct,
}
/// Backend-agnostic async interface to a store of embeddings.
///
/// Implementors must be `Send + Sync`. Every fallible operation reports
/// failure as an [`EmbeddingError`].
#[async_trait]
pub trait VectorStore: Send + Sync {
    /// Name of this store implementation.
    fn name(&self) -> &str;

    /// Vector length this store works with.
    fn dimensions(&self) -> usize;

    /// Writes `embedding` and `metadata` under `id` (insert-or-update).
    ///
    /// Exact replacement semantics are implementation-defined.
    async fn upsert(
        &self,
        id: &str,
        embedding: &Embedding,
        metadata: serde_json::Value,
    ) -> Result<(), EmbeddingError>;

    /// Upserts every `(id, embedding, metadata)` triple in `items`.
    ///
    /// The default implementation calls [`upsert`](Self::upsert) once per
    /// item, in order. It returns the first error it hits; items written
    /// before that error are not rolled back. Backends with a native bulk
    /// write may override this.
    async fn upsert_batch(
        &self,
        items: Vec<(String, Embedding, serde_json::Value)>,
    ) -> Result<(), EmbeddingError> {
        for item in items {
            let (item_id, vector, meta) = item;
            self.upsert(&item_id, &vector, meta).await?;
        }
        Ok(())
    }

    /// Queries for entries near `embedding`.
    ///
    /// Implementations are expected to return at most `limit` results,
    /// narrowed by `filter` when it is `Some` — NOTE(review): ordering and
    /// filter semantics are backend-defined.
    async fn search(
        &self,
        embedding: &Embedding,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<SearchResult>, EmbeddingError>;

    /// Fetches the entry stored under `id`. `Ok(None)` presumably means no
    /// such entry exists.
    async fn get(&self, id: &str) -> Result<Option<SearchResult>, EmbeddingError>;

    /// Removes the entry stored under `id`.
    ///
    /// The default `delete_batch` treats a `true` result as "an entry was
    /// removed".
    async fn delete(&self, id: &str) -> Result<bool, EmbeddingError>;

    /// Deletes each id in `ids` and returns how many deletions reported `true`.
    ///
    /// The default implementation calls [`delete`](Self::delete) once per id,
    /// in order, and returns the first error it hits. Deletions already
    /// performed are not undone.
    async fn delete_batch(&self, ids: &[&str]) -> Result<usize, EmbeddingError> {
        let mut removed: usize = 0;
        for &target in ids.iter() {
            let was_present = self.delete(target).await?;
            removed += usize::from(was_present);
        }
        Ok(removed)
    }

    /// Presumably persists any pending writes; exact semantics are
    /// backend-defined.
    async fn flush(&self) -> Result<(), EmbeddingError>;

    /// Number of entries in the store, as reported by the backend.
    async fn count(&self) -> Result<usize, EmbeddingError>;
}