From 4f83c8603b452d735672660ca53209da0195b226 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jesu=CC=81s=20Pe=CC=81rez?= Date: Wed, 24 Dec 2025 03:22:04 +0000 Subject: [PATCH] chore: add typedialog-ag LLM and agents with MDX --- agents/README.md | 241 +++++ agents/architect.agent.mdx | 59 ++ agents/code-reviewer.agent.mdx | 39 + agents/debugger.agent.mdx | 48 + agents/doc-generator.agent.mdx | 48 + agents/greeting.agent.mdx | 21 + agents/refactor.agent.mdx | 53 ++ agents/summarizer.agent.mdx | 36 + agents/test-generator.agent.mdx | 41 + agents/translator.agent.mdx | 32 + crates/typedialog-agent/README.md | 114 +++ crates/typedialog-agent/quickstart.md | 245 ++++++ .../typedialog-ag-core/Cargo.toml | 73 ++ .../typedialog-ag-core/LLM_INTEGRATION.md | 357 ++++++++ .../examples/llm_execution.rs | 105 +++ .../examples/provider_comparison.rs | 277 ++++++ .../typedialog-ag-core/src/cache/mod.rs | 474 ++++++++++ .../typedialog-ag-core/src/error.rs | 260 ++++++ .../typedialog-ag-core/src/executor/mod.rs | 829 ++++++++++++++++++ .../typedialog-ag-core/src/formats/mod.rs | 39 + .../typedialog-ag-core/src/lib.rs | 183 ++++ .../typedialog-ag-core/src/llm/claude.rs | 517 +++++++++++ .../typedialog-ag-core/src/llm/gemini.rs | 555 ++++++++++++ .../typedialog-ag-core/src/llm/mod.rs | 70 ++ .../typedialog-ag-core/src/llm/ollama.rs | 504 +++++++++++ .../typedialog-ag-core/src/llm/openai.rs | 457 ++++++++++ .../typedialog-ag-core/src/llm/provider.rs | 83 ++ .../typedialog-ag-core/src/nickel/mod.rs | 190 ++++ .../typedialog-ag-core/src/parser/ast.rs | 58 ++ .../src/parser/directives.rs | 28 + .../typedialog-ag-core/src/parser/markdown.rs | 1 + .../typedialog-ag-core/src/parser/mdx.rs | 265 ++++++ .../typedialog-ag-core/src/parser/mod.rs | 56 ++ .../typedialog-ag-core/src/transpiler/mod.rs | 281 ++++++ .../typedialog-ag-core/src/utils/mod.rs | 1 + .../tests/fixtures/architect.agent.mdx | 33 + .../tests/fixtures/code-reviewer.agent.mdx | 34 + .../tests/fixtures/haiku.agent.mdx | 10 + .../tests/fixtures/simple.agent.mdx | 8 + .../tests/integration_test.rs | 295 +++++++ .../tests/simple_integration_test.rs | 90 ++ .../typedialog-agent/typedialog-ag/Cargo.toml | 53 ++ .../typedialog-agent/typedialog-ag/README.md | 445 ++++++++++ .../typedialog-agent/typedialog-ag/src/lib.rs | 395 +++++++++ .../typedialog-ag/src/main.rs | 336 +++++++ 45 files changed, 8339 insertions(+) create mode 100644 agents/README.md create mode 100644 agents/architect.agent.mdx create mode 100644 agents/code-reviewer.agent.mdx create mode 100644 agents/debugger.agent.mdx create mode 100644 agents/doc-generator.agent.mdx create mode 100644 agents/greeting.agent.mdx create mode 100644 agents/refactor.agent.mdx create mode 100644 agents/summarizer.agent.mdx create mode 100644 agents/test-generator.agent.mdx create mode 100644 agents/translator.agent.mdx create mode 100644 crates/typedialog-agent/README.md create mode 100644 crates/typedialog-agent/quickstart.md create mode 100644 crates/typedialog-agent/typedialog-ag-core/Cargo.toml create mode 100644 crates/typedialog-agent/typedialog-ag-core/LLM_INTEGRATION.md create mode 100644 crates/typedialog-agent/typedialog-ag-core/examples/llm_execution.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/examples/provider_comparison.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/cache/mod.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/error.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/executor/mod.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/formats/mod.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/lib.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/llm/claude.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/llm/gemini.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/llm/mod.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/llm/ollama.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/llm/openai.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/llm/provider.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/nickel/mod.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/parser/ast.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/parser/directives.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/parser/markdown.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/parser/mdx.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/parser/mod.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/transpiler/mod.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/src/utils/mod.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/tests/fixtures/architect.agent.mdx create mode 100644 crates/typedialog-agent/typedialog-ag-core/tests/fixtures/code-reviewer.agent.mdx create mode 100644 crates/typedialog-agent/typedialog-ag-core/tests/fixtures/haiku.agent.mdx create mode 100644 crates/typedialog-agent/typedialog-ag-core/tests/fixtures/simple.agent.mdx create mode 100644 crates/typedialog-agent/typedialog-ag-core/tests/integration_test.rs create mode 100644 crates/typedialog-agent/typedialog-ag-core/tests/simple_integration_test.rs create mode 100644 crates/typedialog-agent/typedialog-ag/Cargo.toml create mode 100644 crates/typedialog-agent/typedialog-ag/README.md create mode 100644 crates/typedialog-agent/typedialog-ag/src/lib.rs create mode 100644 crates/typedialog-agent/typedialog-ag/src/main.rs diff --git a/agents/README.md b/agents/README.md new file mode 100644 index 0000000..344f551 --- /dev/null +++ b/agents/README.md @@ -0,0 +1,241 @@ +# TypeDialog Agent Examples + +This directory contains example agents demonstrating various capabilities of the TypeDialog Agent system. + +## Available Agents + +### 1. greeting.agent.mdx +**Purpose**: Simple friendly greeting +**Inputs**: `name` (String) +**LLM**: Claude 3.5 Haiku + +**Usage**: +```bash +# CLI +typedialog-ag agents/greeting.agent.mdx + +# HTTP +curl -X POST http://localhost:8765/agents/greeting/execute \ + -H "Content-Type: application/json" \ + -d '{"name":"Alice"}' +``` + +### 2. code-reviewer.agent.mdx +**Purpose**: Comprehensive code review +**Inputs**: `language` (String), `code` (String) +**LLM**: Claude Opus 4.5 +**Validation**: Must contain "## Review" and "## Suggestions" + +**Usage**: +```bash +curl -X POST http://localhost:8765/agents/code-reviewer/execute \ + -H "Content-Type: application/json" \ + -d '{ + "language": "rust", + "code": "fn add(a: i32, b: i32) -> i32 { a + b }" + }' +``` + +### 3. architect.agent.mdx +**Purpose**: Software architecture design +**Inputs**: `feature` (String), `tech_stack?` (String, optional) +**LLM**: Claude Opus 4.5 +**Validation**: Must contain "## Architecture", "## Components", "## Data Flow" + +**Usage**: +```bash +curl -X POST http://localhost:8765/agents/architect/execute \ + -H "Content-Type: application/json" \ + -d '{ + "feature": "Real-time chat system", + "tech_stack": "Rust, WebSockets, PostgreSQL" + }' +``` + +### 4. summarizer.agent.mdx +**Purpose**: Text summarization +**Inputs**: `text` (String), `style?` (String, optional) +**LLM**: Claude 3.5 Sonnet + +**Usage**: +```bash +curl -X POST http://localhost:8765/agents/summarizer/execute \ + -H "Content-Type: application/json" \ + -d '{ + "text": "Long article text here...", + "style": "technical" + }' +``` + +### 5. test-generator.agent.mdx +**Purpose**: Generate unit tests +**Inputs**: `language` (String), `function_code` (String) +**LLM**: Claude 3.5 Sonnet +**Validation**: Must contain "test" and "assert" + +**Usage**: +```bash +curl -X POST http://localhost:8765/agents/test-generator/execute \ + -H "Content-Type: application/json" \ + -d '{ + "language": "python", + "function_code": "def factorial(n):\\n return 1 if n == 0 else n * factorial(n-1)" + }' +``` + +### 6. doc-generator.agent.mdx +**Purpose**: Generate technical documentation +**Inputs**: `code` (String), `language` (String) +**LLM**: Claude 3.5 Sonnet + +**Usage**: +```bash +curl -X POST http://localhost:8765/agents/doc-generator/execute \ + -H "Content-Type: application/json" \ + -d '{ + "language": "javascript", + "code": "function debounce(fn, delay) { ... }" + }' +``` + +### 7. translator.agent.mdx +**Purpose**: Language translation +**Inputs**: `text` (String), `source_lang` (String), `target_lang` (String) +**LLM**: Claude 3.5 Sonnet + +**Usage**: +```bash +curl -X POST http://localhost:8765/agents/translator/execute \ + -H "Content-Type: application/json" \ + -d '{ + "text": "Hello, how are you?", + "source_lang": "English", + "target_lang": "Spanish" + }' +``` + +### 8. debugger.agent.mdx +**Purpose**: Debug code issues +**Inputs**: `code` (String), `error` (String), `language` (String) +**LLM**: Claude Opus 4.5 +**Validation**: Must contain "## Root Cause" and "## Fix" + +**Usage**: +```bash +curl -X POST http://localhost:8765/agents/debugger/execute \ + -H "Content-Type: application/json" \ + -d '{ + "language": "rust", + "code": "let x = vec![1,2,3]; println!(\"{}\", x[5]);", + "error": "index out of bounds: the len is 3 but the index is 5" + }' +``` + +### 9. refactor.agent.mdx +**Purpose**: Code refactoring +**Inputs**: `code` (String), `language` (String), `goal?` (String, optional) +**LLM**: Claude Opus 4.5 +**Validation**: Must contain "## Refactored Code" and "## Changes" + +**Usage**: +```bash +curl -X POST http://localhost:8765/agents/refactor/execute \ + -H "Content-Type: application/json" \ + -d '{ + "language": "typescript", + "code": "function calc(a,b,op){if(op==='+')return a+b;if(op==='-')return a-b;}", + "goal": "Improve readability and type safety" + }' +``` + +## Features Demonstrated + +### Input Types +- **Required inputs**: `name`, `code`, `language` +- **Optional inputs**: `style?`, `tech_stack?`, `goal?` + +### LLM Models +- **Claude 3.5 Haiku**: Fast, cost-effective (greeting) +- **Claude 3.5 Sonnet**: Balanced performance (summarizer, test-generator, doc-generator, translator) +- **Claude Opus 4.5**: Maximum capability (code-reviewer, architect, debugger, refactor) + +### Validation Rules +- **must_contain**: Ensure specific sections in output +- **format**: Enforce output format (markdown, text) +- **min_length/max_length**: Control output size + +### Temperature Settings +- **0.3**: Deterministic, factual (code review, debugging, translation) +- **0.4-0.5**: Balanced creativity (architecture, refactoring) +- **0.8**: More creative (greeting) + +### Max Tokens +- **150**: Short responses (greeting) +- **500-1000**: Medium responses (summarizer, translator) +- **2000-3000**: Detailed responses (code review, documentation) +- **4000**: Comprehensive responses (architecture) + +## Testing Agents + +### CLI Testing +```bash +# Validate agent +typedialog-ag validate agents/greeting.agent.mdx + +# Transpile to Nickel +typedialog-ag transpile agents/greeting.agent.mdx + +# Execute (interactive) +typedialog-ag agents/greeting.agent.mdx +``` + +### HTTP Server Testing +```bash +# Start server +typedialog-ag serve --port 8765 + +# Test health +curl http://localhost:8765/health + +# Validate agent +curl -X POST http://localhost:8765/validate \ + -H "Content-Type: application/json" \ + -d '{"agent_file":"agents/greeting.agent.mdx"}' + +# Execute agent +curl -X POST http://localhost:8765/agents/greeting/execute \ + -H "Content-Type: application/json" \ + -d '{"name":"World"}' +``` + +## Best Practices + +1. **Clear Role**: Define specific role in `@agent` directive +2. **Type Inputs**: Declare input types explicitly +3. **Validate Output**: Use `@validate` to ensure quality +4. **Right Model**: Choose LLM based on task complexity +5. **Temperature**: Lower for factual, higher for creative +6. **Token Limits**: Set appropriate max_tokens for task +7. **Prompts**: Be specific and provide clear instructions +8. **Examples**: Include examples in prompt when helpful + +## Environment Variables + +Ensure you have API keys configured: +```bash +export ANTHROPIC_API_KEY=sk-ant-... +export OPENAI_API_KEY=sk-... # If using OpenAI models +``` + +## Cache + +The system automatically caches transpiled Nickel code: +```bash +# Check cache stats +typedialog-ag cache stats + +# Clear cache +typedialog-ag cache clear +``` + +Cache location: `~/.typeagent/cache/` diff --git a/agents/architect.agent.mdx b/agents/architect.agent.mdx new file mode 100644 index 0000000..ec77ae6 --- /dev/null +++ b/agents/architect.agent.mdx @@ -0,0 +1,59 @@ +--- +@agent { + role: software architect, + llm: claude-opus-4-5-20251101, + max_tokens: 4000, + temperature: 0.5 +} + +@input feature: String +@input tech_stack?: String + +@validate output { + must_contain: ["## Architecture", "## Components", "## Data Flow"], + format: markdown, + min_length: 500 +} +--- + +# Architecture Design Request + +**Feature**: {{ feature }} + +{% if tech_stack %} +**Tech Stack**: {{ tech_stack }} +{% else %} +**Tech Stack**: Choose appropriate technologies based on requirements +{% endif %} + +As a senior software architect, design a robust, scalable architecture for this feature. + +Provide: + +## Architecture +- High-level system design +- Architectural patterns used +- Why this approach + +## Components +- Key components/modules +- Responsibilities of each +- Communication patterns + +## Data Flow +- How data moves through the system +- State management approach +- API contracts (if applicable) + +## Technical Considerations +- Scalability approach +- Performance considerations +- Security measures +- Trade-offs made + +## Implementation Strategy +- Suggested implementation order +- Critical path items +- Risk mitigation + +Keep it practical and implementation-ready. diff --git a/agents/code-reviewer.agent.mdx b/agents/code-reviewer.agent.mdx new file mode 100644 index 0000000..241049d --- /dev/null +++ b/agents/code-reviewer.agent.mdx @@ -0,0 +1,39 @@ +--- +@agent { + role: senior code reviewer, + llm: claude-opus-4-5-20251101, + max_tokens: 2000, + temperature: 0.3 +} + +@input language: String +@input code: String + +@validate output { + must_contain: ["## Review", "## Suggestions"], + format: markdown, + min_length: 100 +} +--- + +# Code Review Request + +**Language**: {{ language }} + +**Code to review**: +```{{ language }} +{{ code }} +``` + +Please provide a thorough code review covering: + +1. **Code Quality**: Structure, readability, maintainability +2. **Best Practices**: Language-specific conventions and patterns +3. **Potential Issues**: Bugs, edge cases, security concerns +4. **Performance**: Efficiency and optimization opportunities +5. **Suggestions**: Concrete improvements with examples + +Format your response with clear sections: +## Review +## Suggestions +## Security Considerations (if applicable) diff --git a/agents/debugger.agent.mdx b/agents/debugger.agent.mdx new file mode 100644 index 0000000..4f80225 --- /dev/null +++ b/agents/debugger.agent.mdx @@ -0,0 +1,48 @@ +--- +@agent { + role: debugging expert, + llm: claude-opus-4-5-20251101, + max_tokens: 2000, + temperature: 0.3 +} + +@input code: String +@input error: String +@input language: String + +@validate output { + must_contain: ["## Root Cause", "## Fix"], + format: markdown, + min_length: 150 +} +--- + +# Debug Request + +**Language**: {{ language }} + +**Code with issue**: +```{{ language }} +{{ code }} +``` + +**Error/Problem**: +``` +{{ error }} +``` + +Please analyze this issue and provide: + +## Root Cause +What's causing the problem and why + +## Fix +Corrected code with explanation of changes + +## Prevention +How to avoid similar issues in the future + +## Testing +How to verify the fix works + +Be specific and provide complete, working code. diff --git a/agents/doc-generator.agent.mdx b/agents/doc-generator.agent.mdx new file mode 100644 index 0000000..bcef87c --- /dev/null +++ b/agents/doc-generator.agent.mdx @@ -0,0 +1,48 @@ +--- +@agent { + role: technical writer, + llm: claude-3-5-sonnet-20241022, + max_tokens: 3000, + temperature: 0.5 +} + +@input code: String +@input language: String + +@validate output { + must_contain: ["##"], + format: markdown, + min_length: 200 +} +--- + +# Documentation Generation + +**Language**: {{ language }} + +**Code**: +```{{ language }} +{{ code }} +``` + +Generate comprehensive documentation for this code including: + +## Overview +Brief description of what this code does + +## API Reference +- Function/class signatures +- Parameters with types and descriptions +- Return values +- Exceptions/errors + +## Usage Examples +Practical examples showing how to use this code + +## Implementation Notes +Key implementation details developers should know + +## Complexity +Time and space complexity if relevant + +Write clear, concise documentation that helps developers understand and use this code effectively. diff --git a/agents/greeting.agent.mdx b/agents/greeting.agent.mdx new file mode 100644 index 0000000..b28b5bb --- /dev/null +++ b/agents/greeting.agent.mdx @@ -0,0 +1,21 @@ +--- +@agent { + role: friendly assistant, + llm: claude-3-5-haiku-20241022, + max_tokens: 150, + temperature: 0.8 +} + +@input name: String + +@validate output { + min_length: 10, + max_length: 300, + format: text +} +--- + +Hello {{ name }}! + +Please respond with a warm, friendly greeting. Make it personal and cheerful. +Keep it brief (2-3 sentences). diff --git a/agents/refactor.agent.mdx b/agents/refactor.agent.mdx new file mode 100644 index 0000000..586c94e --- /dev/null +++ b/agents/refactor.agent.mdx @@ -0,0 +1,53 @@ +--- +@agent { + role: refactoring specialist, + llm: claude-opus-4-5-20251101, + max_tokens: 3000, + temperature: 0.4 +} + +@input code: String +@input language: String +@input goal?: String + +@validate output { + must_contain: ["## Refactored Code", "## Changes"], + format: markdown, + min_length: 200 +} +--- + +# Refactoring Request + +**Language**: {{ language }} + +**Current code**: +```{{ language }} +{{ code }} +``` + +{% if goal %} +**Refactoring goal**: {{ goal }} +{% else %} +**Goal**: Improve code quality, readability, and maintainability +{% endif %} + +Please refactor this code following these principles: +- DRY (Don't Repeat Yourself) +- SOLID principles +- Clean code practices +- Language-specific idioms + +Provide: + +## Refactored Code +Complete refactored version + +## Changes Made +Specific improvements with rationale + +## Benefits +How the refactored code is better + +## Considerations +Any trade-offs or notes about the refactoring diff --git a/agents/summarizer.agent.mdx b/agents/summarizer.agent.mdx new file mode 100644 index 0000000..04c6640 --- /dev/null +++ b/agents/summarizer.agent.mdx @@ -0,0 +1,36 @@ +--- +@agent { + role: content summarizer, + llm: claude-3-5-sonnet-20241022, + max_tokens: 500, + temperature: 0.3 +} + +@input text: String +@input style?: String + +@validate output { + max_length: 1000, + format: text, + min_length: 50 +} +--- + +# Summarization Task + +**Content to summarize**: +{{ text }} + +{% if style %} +**Requested style**: {{ style }} +{% else %} +**Style**: Concise and clear +{% endif %} + +Please provide a well-structured summary that: +- Captures the key points +- Maintains accuracy +- Is easy to understand +- Follows the requested style + +Use bullet points for clarity if appropriate. diff --git a/agents/test-generator.agent.mdx b/agents/test-generator.agent.mdx new file mode 100644 index 0000000..80efe5a --- /dev/null +++ b/agents/test-generator.agent.mdx @@ -0,0 +1,41 @@ +--- +@agent { + role: test engineer, + llm: claude-3-5-sonnet-20241022, + max_tokens: 2000, + temperature: 0.4 +} + +@input language: String +@input function_code: String + +@validate output { + must_contain: ["test", "assert"], + format: text, + min_length: 100 +} +--- + +# Test Generation Request + +**Language**: {{ language }} + +**Function to test**: +```{{ language }} +{{ function_code }} +``` + +Generate comprehensive unit tests for this function covering: + +1. **Happy path**: Normal expected usage +2. **Edge cases**: Boundary conditions, empty inputs, etc. +3. **Error cases**: Invalid inputs, error handling +4. **Special cases**: Any language-specific considerations + +Provide: +- Complete, runnable test code +- Clear test names describing what's being tested +- Assertions that verify expected behavior +- Comments explaining non-obvious test cases + +Use idiomatic {{ language }} testing conventions. diff --git a/agents/translator.agent.mdx b/agents/translator.agent.mdx new file mode 100644 index 0000000..f9b92f0 --- /dev/null +++ b/agents/translator.agent.mdx @@ -0,0 +1,32 @@ +--- +@agent { + role: translator, + llm: claude-3-5-sonnet-20241022, + max_tokens: 1000, + temperature: 0.3 +} + +@input text: String +@input source_lang: String +@input target_lang: String + +@validate output { + min_length: 5, + format: text +} +--- + +# Translation Task + +Translate the following text from {{ source_lang }} to {{ target_lang }}. + +**Source text** ({{ source_lang }}): +{{ text }} + +Requirements: +- Maintain the original meaning and tone +- Use natural, idiomatic {{ target_lang }} +- Preserve any technical terms appropriately +- Keep formatting if present + +Provide only the translation, no explanations. diff --git a/crates/typedialog-agent/README.md b/crates/typedialog-agent/README.md new file mode 100644 index 0000000..5ec6814 --- /dev/null +++ b/crates/typedialog-agent/README.md @@ -0,0 +1,114 @@ +# typedialog-agent + +> Type-safe AI agent execution with 3-layer validation pipeline + +Part of the [TypeDialog](https://github.com/yourusername/typedialog) ecosystem. + +## Features + +- **3-Layer Pipeline**: MDX → Nickel → MD with validation at each step +- **Type Safety**: Compile-time type checking via Nickel +- **Multi-Format**: .agent.mdx, .agent.ncl, .agent.md support +- **CLI + Server**: Execute locally or via HTTP API +- **Cache Layer**: Memory + disk caching for fast re-execution +- **Integration**: Works with typedialog-ai for form-based agent creation + +## Architecture + +``` +Layer 1: Markup Parser + .agent.mdx → AST + Parse @directives, {{variables}}, markdown + +Layer 2: Nickel Transpiler + Evaluator + AST → Nickel code → Type check → AgentDefinition + +Layer 3: Executor + AgentDefinition + Inputs → LLM → Validated Output +``` + +## Quick Start + +### CLI Usage + +```bash +# Execute agent +typeagent architect.agent.mdx --input feature_name="authentication" + +# Transpile to Nickel (inspect generated code) +typeagent transpile architect.agent.mdx -o architect.agent.ncl + +# Validate without execution +typeagent validate architect.agent.mdx + +# Start HTTP server +typeagent serve --port 8765 +``` + +### Programmatic Usage + +```rust +use typedialog_ag_core::{AgentLoader, AgentFormat}; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let loader = AgentLoader::new(); + + // Load agent + let agent = loader.load(Path::new("architect.agent.mdx")).await?; + + // Execute with inputs + let inputs = [("feature_name", "auth")].into_iter() + .map(|(k, v)| (k.to_string(), serde_json::Value::String(v.to_string()))) + .collect(); + + let result = loader.execute(&agent, inputs).await?; + + println!("{}", result.output); + Ok(()) +} +``` + +## Integration with TypeDialog Ecosystem + +### With typedialog-ai + +```bash +# 1. Create agent using AI-assisted form +typedialog form agent-builder.toml --backend ai +# → Generates architect.agent.mdx + +# 2. Execute with typeagent +typeagent architect.agent.mdx +``` + +### With Vapora (MCP Plugin) + +```rust +// Vapora uses typedialog-ag-core as library +use typedialog_ag_core::AgentLoader; + +let loader = AgentLoader::new(); +let agent = loader.load(Path::new("agents/architect.agent.mdx")).await?; +// Execute via Vapora orchestration +``` + +## Project Structure + +``` +typedialog-agent/ +├── typedialog-ag-core/ # Core library (reusable) +├── typedialog-ag/ # CLI binary +└── typedialog-ag-server/ # HTTP server +``` + +## Documentation + +- [Implementation Plan](../../.coder/2025-12-23-typedialog-agent-implementation.plan.md) +- [Markup Syntax](./docs/MARKUP_SYNTAX.md) +- [Nickel Integration](./docs/nickel.md) +- [HTTP API](./docs/API.md) + +## License + +MIT diff --git a/crates/typedialog-agent/quickstart.md b/crates/typedialog-agent/quickstart.md new file mode 100644 index 0000000..e15e63d --- /dev/null +++ b/crates/typedialog-agent/quickstart.md @@ -0,0 +1,245 @@ +# TypeAgent Quick Start + +Get started with TypeAgent in 5 minutes. + +## Prerequisites + +- Rust toolchain (1.75+) +- Anthropic API key + +## Setup + +### 1. Set API Key + +```bash +export ANTHROPIC_API_KEY=your-api-key-here +``` + +### 2. Build TypeAgent + +```bash +cd crates/typedialog-agent/typedialog-ag +cargo build --release +``` + +### 3. Add to PATH (optional) + +```bash +export PATH="$PWD/target/release:$PATH" +``` + +## Your First Agent + +### Create an Agent File + +Create `hello.agent.mdx`: + +```markdown +--- +@agent { + role: friendly assistant, + llm: claude-3-5-haiku-20241022 +} + +@input name: String +--- + +Say hello to {{ name }} in a creative and friendly way! +``` + +### Run It + +```bash +typeagent hello.agent.mdx +``` + +You'll see: + +``` +🤖 TypeAgent Executor + +✓ Parsed agent definition +✓ Transpiled to Nickel +✓ Evaluated agent definition + +Agent Configuration: + Role: friendly assistant + Model: claude-3-5-haiku-20241022 + Max tokens: 4096 + Temperature: 0.7 + +name (String): Alice█ +``` + +Type a name and press Enter. The agent will execute and show the response! + +## Next Steps + +### Try the Examples + +```bash +# Simple greeting +typeagent tests/fixtures/simple.agent.mdx --yes + +# Creative haiku +typeagent tests/fixtures/haiku.agent.mdx +``` + +### Validate Before Running + +```bash +typeagent validate hello.agent.mdx +``` + +### See the Nickel Code + +```bash +typeagent transpile hello.agent.mdx +``` + +## Common Workflows + +### Development Workflow + +```bash +# 1. Write your agent +vim agent.mdx + +# 2. Validate it +typeagent validate agent.mdx + +# 3. Test with verbose output +typeagent agent.mdx --verbose + +# 4. Run in production +typeagent agent.mdx +``` + +### Quick Iteration + +Use `--yes` to skip prompts during development: + +```bash +# Edit agent.mdx +# Run without prompts +typeagent agent.mdx --yes +``` + +## Advanced Features + +### Context Injection + +Import files into your agent: + +```markdown +@import "./docs/**/*.md" as documentation +@shell "git log --oneline -5" as recent_commits +``` + +### Output Validation + +Ensure output meets requirements: + +```markdown +@validate output { + must_contain: ["Security", "Performance"], + format: markdown, + min_length: 100 +} +``` + +### Conditional Logic + +Use Tera template syntax: + +```markdown +{% if has_description %} +Description: {{ description }} +{% endif %} +``` + +## Troubleshooting + +### "ANTHROPIC_API_KEY not set" + +```bash +export ANTHROPIC_API_KEY=sk-ant-... +``` + +### "Failed to parse agent MDX" + +Check your frontmatter syntax: + +```markdown +--- +@agent { + role: assistant, # <- comma required + llm: claude-3-5-haiku-20241022 +} +--- +``` + +### "Permission denied" + +```bash +chmod +x ./target/release/typeagent +``` + +## Learn More + +- [CLI Documentation](typedialog-ag/README.md) +- [LLM Integration Guide](typedialog-ag-core/LLM_INTEGRATION.md) +- [Example Agents](typedialog-ag-core/tests/fixtures/) + +## Support + +- GitHub Issues: https://github.com/jesusperezlorenzo/typedialog/issues +- API Docs: `cargo doc --open --package typedialog-ag-core` + +## Next Example: Architecture Agent + +Create `architect.agent.mdx`: + +```markdown +--- +@agent { + role: software architect, + llm: claude-3-5-sonnet-20241022 +} + +@input feature_name: String +@input requirements?: String + +@validate output { + must_contain: ["## Architecture", "## Components"], + format: markdown, + min_length: 200 +} +--- + +# Architecture Design: {{ feature_name }} + +You are an experienced software architect. Design a comprehensive architecture for: + +**Feature**: {{ feature_name }} + +{% if requirements %} +**Requirements**: {{ requirements }} +{% endif %} + +Provide: +1. High-level architecture overview +2. Component breakdown +3. Data flow +4. Technology recommendations + +Use clear markdown formatting with sections. +``` + +Run it: + +```bash +typeagent architect.agent.mdx +``` + +The agent will prompt for inputs and generate a complete architecture design! diff --git a/crates/typedialog-agent/typedialog-ag-core/Cargo.toml b/crates/typedialog-agent/typedialog-ag-core/Cargo.toml new file mode 100644 index 0000000..dbf2c23 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/Cargo.toml @@ -0,0 +1,73 @@ +[package] +name = "typedialog-ag-core" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +description = "Core library for type-safe AI agent execution with MDX → Nickel → MD pipeline" +keywords.workspace = true +categories.workspace = true + +[dependencies] +# Async +tokio = { workspace = true } +futures = { workspace = true } +async-trait = { workspace = true } + +# Nickel +nickel-lang-core = { workspace = true } + +# Templates +tera = { workspace = true } + +# Serialization +serde = { workspace = true } +serde_json = { workspace = true } +serde_yaml = { workspace = true } +toml = { workspace = true } + +# HTTP Client +reqwest = { workspace = true } + +# Glob & Files +globset = { workspace = true } +ignore = { workspace = true } + +# Cache +lru = { workspace = true } +sha2 = { workspace = true } +hex = { workspace = true } +bincode = { workspace = true } + +# Parsing +nom = { workspace = true } + +# Error handling +thiserror = { workspace = true } +anyhow = { workspace = true } + +# Logging +tracing = { workspace = true } + +# Utilities +uuid = { workspace = true } +chrono = { workspace = true } +dirs = { workspace = true } + +[dev-dependencies] +proptest.workspace = true +criterion.workspace = true + +[features] +default = ["markup", "nickel", "cache"] +markup = [] # MDX parsing + transpiler +nickel = [] # Nickel evaluation +markdown = [] # Legacy .agent.md support +cache = [] # Cache layer + +[lib] +name = "typedialog_ag_core" +path = "src/lib.rs" + diff --git a/crates/typedialog-agent/typedialog-ag-core/LLM_INTEGRATION.md b/crates/typedialog-agent/typedialog-ag-core/LLM_INTEGRATION.md new file mode 100644 index 0000000..bfc2e04 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/LLM_INTEGRATION.md @@ -0,0 +1,357 @@ +# LLM Integration + +TypeAgent Core now includes full LLM execution capabilities, allowing agents to call real language models. + +## Supported Providers + +### Claude (Anthropic) +- ✅ Fully supported with streaming +- Models: `claude-3-5-haiku-20241022`, `claude-3-5-sonnet-20241022`, `claude-opus-4`, etc. +- Requires: `ANTHROPIC_API_KEY` environment variable +- Features: Full SSE streaming, token usage tracking + +### OpenAI +- ✅ Fully supported with streaming +- Models: `gpt-4o`, `gpt-4o-mini`, `gpt-4-turbo`, `o1`, `o3`, `o4-mini`, etc. +- Requires: `OPENAI_API_KEY` environment variable +- Features: Full SSE streaming, token usage tracking + +### Google Gemini +- ✅ Fully supported with streaming +- Models: `gemini-2.0-flash-exp`, `gemini-1.5-pro`, `gemini-1.5-flash`, etc. +- Requires: `GEMINI_API_KEY` or `GOOGLE_API_KEY` environment variable +- Features: Full JSON streaming, token usage tracking +- Note: Assistant role is mapped to "model" in Gemini API + +### Ollama (Local Models) +- ✅ Fully supported with streaming +- Models: `llama2`, `mistral`, `phi`, `codellama`, `mixtral`, `qwen`, etc. +- Requires: Ollama running locally (default: http://localhost:11434) +- Optional: `OLLAMA_BASE_URL` to override endpoint +- Features: Full JSON streaming, token usage tracking, privacy (local execution) +- Note: No API key required - runs entirely on your machine + +## Setup + +### 1. Set API Key + +**For Claude:** +```bash +export ANTHROPIC_API_KEY=your-api-key-here +``` + +**For OpenAI:** +```bash +export OPENAI_API_KEY=your-api-key-here +``` + +**For Gemini:** +```bash +export GEMINI_API_KEY=your-api-key-here +# Or use GOOGLE_API_KEY +export GOOGLE_API_KEY=your-api-key-here +``` + +**For Ollama (local models):** +```bash +# Install and start Ollama first +# Download from: https://ollama.ai +ollama serve # Start the Ollama server + +# Pull a model (in another terminal) +ollama pull llama2 + +# Optional: Override default URL +export OLLAMA_BASE_URL=http://localhost:11434 +``` + +### 2. Create Agent MDX File + +```markdown +--- +@agent { + role: assistant, + llm: claude-3-5-haiku-20241022 +} +--- + +Hello {{ name }}! How can I help you today? +``` + +### 3. Execute Agent + +```rust +use typedialog_ag_core::{MarkupParser, NickelTranspiler, NickelEvaluator, AgentExecutor}; +use std::collections::HashMap; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Parse MDX + let parser = MarkupParser::new(); + let ast = parser.parse(mdx_content)?; + + // Transpile to Nickel + let transpiler = NickelTranspiler::new(); + let nickel_code = transpiler.transpile(&ast)?; + + // Evaluate to AgentDefinition + let evaluator = NickelEvaluator::new(); + let agent_def = evaluator.evaluate(&nickel_code)?; + + // Execute with LLM + let executor = AgentExecutor::new(); + let mut inputs = HashMap::new(); + inputs.insert("name".to_string(), serde_json::json!("Alice")); + + let result = executor.execute(&agent_def, inputs).await?; + println!("Response: {}", result.output); + println!("Tokens: {}", result.metadata.tokens.unwrap_or(0)); + + Ok(()) +} +``` + +## Configuration + +Agent configuration is specified in the MDX frontmatter: + +```yaml +@agent { + role: creative writer, # System prompt role + llm: claude-3-5-haiku-20241022, # Model name + tools: [], # Tool calling (future) + max_tokens: 4096, # Optional (default: 4096) + temperature: 0.7 # Optional (default: 0.7) +} +``` + +## LLM Provider Architecture + +### Provider Trait + +```rust +#[async_trait] +pub trait LlmProvider: Send + Sync { + async fn complete(&self, request: LlmRequest) -> Result; + fn name(&self) -> &str; +} +``` + +### Request/Response + +```rust +pub struct LlmRequest { + pub model: String, + pub messages: Vec, + pub max_tokens: Option, + pub temperature: Option, + pub system: Option, +} + +pub struct LlmResponse { + pub content: String, + pub model: String, + pub usage: Option, +} +``` + +### Automatic Provider Selection + +The executor automatically selects the correct provider based on model name: + +- `claude-*`, `anthropic-*` → ClaudeProvider +- `gpt-*`, `o1-*`, `o3-*`, `o4-*` → OpenAIProvider +- `gemini-*` → GeminiProvider +- `llama*`, `mistral*`, `phi*`, `codellama*`, `mixtral*`, `qwen*`, etc. → OllamaProvider + +## Examples + +### Run Complete Pipeline + +```bash +cargo run --example llm_execution +``` + +### Compare All Providers + +```bash +# Run all four providers with the same prompt +cargo run --example provider_comparison + +# Run specific provider only +cargo run --example provider_comparison claude +cargo run --example provider_comparison openai +cargo run --example provider_comparison gemini +cargo run --example provider_comparison ollama +``` + +### Run with Test (requires API key) + +```bash +cargo test --package typedialog-ag-core -- test_execute_with_real_llm --exact --ignored --nocapture +``` + +### Integration Test + +```bash +cargo test --package typedialog-ag-core --test simple_integration_test -- test_complete_pipeline_with_llm --exact --ignored --nocapture +``` + +## Error Handling + +```rust +match executor.execute(&agent_def, inputs).await { + Ok(result) => { + if !result.validation_passed { + eprintln!("Validation errors: {:?}", result.validation_errors); + } + println!("Output: {}", result.output); + } + Err(e) => { + if e.to_string().contains("ANTHROPIC_API_KEY") { + eprintln!("Error: API key not set"); + } else { + eprintln!("Execution failed: {}", e); + } + } +} +``` + +## Token Usage Tracking + +All LLM responses include token usage information: + +```rust +let result = executor.execute(&agent_def, inputs).await?; + +if let Some(tokens) = result.metadata.tokens { + println!("Tokens used: {}", tokens); +} + +if let Some(usage) = response.usage { + println!("Input tokens: {}", usage.input_tokens); + println!("Output tokens: {}", usage.output_tokens); + println!("Total tokens: {}", usage.total_tokens); +} +``` + +## Context Injection + +Agents can load context from files, URLs, and shell commands before LLM execution: + +```markdown +--- +@agent { + role: code reviewer, + llm: claude-3-5-sonnet-20241022 +} + +@import "./src/**/*.rs" as source_code +@shell "git diff HEAD~1" as recent_changes + +--- + +Review the following code: + +**Source Code:** +{{ source_code }} + +**Recent Changes:** +{{ recent_changes }} + +Provide security and performance analysis. +``` + +The executor loads all context before calling the LLM, so the model receives the fully rendered prompt with all imported content. + +## Validation + +Output validation runs automatically after LLM execution: + +```markdown +--- +@validate output { + must_contain: ["Security", "Performance"], + format: markdown, + min_length: 100 +} +--- +``` + +Validation results are included in `ExecutionResult`: + +```rust +if !result.validation_passed { + for error in result.validation_errors { + eprintln!("Validation error: {}", error); + } +} +``` + +## Cost Optimization + +### Use Appropriate Models + +- `claude-3-5-haiku-20241022`: Fast, cheap, good for simple tasks +- `claude-3-5-sonnet-20241022`: Balanced performance and cost +- `claude-opus-4`: Most capable, highest cost + +### Limit Token Usage + +```rust +agent_def.config.max_tokens = 500; // Limit response length +``` + +### Cache Context + +The executor supports context caching to avoid re-loading files on each execution (implementation varies by provider). + +## Testing Without API Key + +Tests that require real LLM execution are marked with `#[ignore]`: + +```bash +# Run only non-LLM tests +cargo test --package typedialog-ag-core + +# Run LLM tests (requires ANTHROPIC_API_KEY) +cargo test --package typedialog-ag-core -- --ignored +``` + +## Implementation Files + +- **Provider Trait**: `src/llm/provider.rs` +- **Claude Client**: `src/llm/claude.rs` +- **OpenAI Client**: `src/llm/openai.rs` +- **Gemini Client**: `src/llm/gemini.rs` +- **Ollama Client**: `src/llm/ollama.rs` +- **Executor Integration**: `src/executor/mod.rs` +- **Example**: `examples/llm_execution.rs` +- **Multi-Provider Demo**: `examples/provider_comparison.rs` +- **Tests**: `tests/simple_integration_test.rs` + +## Streaming Support + +All four providers (Claude, OpenAI, Gemini, Ollama) support real-time streaming: + +```rust +use typedialog_ag_core::AgentExecutor; + +let executor = AgentExecutor::new(); +let result = executor.execute_streaming(&agent_def, inputs, |chunk| { + print!("{}", chunk); + std::io::stdout().flush().unwrap(); +}).await?; + +println!("\n\nFinal output: {}", result.output); +println!("Tokens: {:?}", result.metadata.tokens); +``` + +The CLI automatically uses streaming for real-time token display. + +### Token Usage in Streaming + +- **Claude**: ✅ Provides token usage in stream (via `message_delta` event) +- **OpenAI**: ❌ No token usage in stream (API limitation - only in non-streaming mode) +- **Gemini**: ✅ Provides token usage in stream (via `usageMetadata` in final chunk) +- **Ollama**: ✅ Provides token usage in stream (via `prompt_eval_count`/`eval_count` in done event) diff --git a/crates/typedialog-agent/typedialog-ag-core/examples/llm_execution.rs b/crates/typedialog-agent/typedialog-ag-core/examples/llm_execution.rs new file mode 100644 index 0000000..6ba1f82 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/examples/llm_execution.rs @@ -0,0 +1,105 @@ +//! Example demonstrating complete LLM execution pipeline +//! +//! This example shows how to: +//! 1. Parse an MDX agent file +//! 2. Transpile to Nickel +//! 3. Evaluate to AgentDefinition +//! 4. Execute with real LLM (Claude) +//! +//! Requirements: +//! - Set ANTHROPIC_API_KEY environment variable +//! +//! Run with: +//! ```bash +//! export ANTHROPIC_API_KEY=your-api-key +//! cargo run --example llm_execution +//! ``` + +use std::collections::HashMap; +use typedialog_ag_core::executor::AgentExecutor; +use typedialog_ag_core::nickel::NickelEvaluator; +use typedialog_ag_core::parser::MarkupParser; +use typedialog_ag_core::transpiler::NickelTranspiler; + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("=== TypeAgent LLM Execution Example ===\n"); + + // Define a simple agent in MDX format + let mdx_content = r#"--- +@agent { + role: creative writer, + llm: claude-3-5-haiku-20241022 +} +--- + +Write a haiku about {{ topic }}. Return only the haiku, nothing else. +"#; + + println!("1. Parsing MDX agent definition..."); + let parser = MarkupParser::new(); + let ast = parser.parse(mdx_content)?; + println!(" ✓ Parsed {} AST nodes\n", ast.nodes.len()); + + println!("2. Transpiling to Nickel..."); + let transpiler = NickelTranspiler::new(); + let nickel_code = transpiler.transpile(&ast)?; + println!(" ✓ Generated Nickel code:\n"); + println!("{}\n", nickel_code); + + println!("3. Evaluating Nickel to AgentDefinition..."); + let evaluator = NickelEvaluator::new(); + let mut agent_def = evaluator.evaluate(&nickel_code)?; + + // Configure for production + agent_def.config.max_tokens = 150; + agent_def.config.temperature = 0.8; + + println!(" ✓ Agent configured:"); + println!(" - Role: {}", agent_def.config.role); + println!(" - Model: {}", agent_def.config.llm); + println!(" - Max tokens: {}", agent_def.config.max_tokens); + println!(" - Temperature: {}\n", agent_def.config.temperature); + + println!("4. Executing agent with LLM..."); + let executor = AgentExecutor::new(); + + // Provide input + let mut inputs = HashMap::new(); + inputs.insert( + "topic".to_string(), + serde_json::json!("programming in Rust"), + ); + + // Execute (calls Claude API) + println!(" → Calling Claude API..."); + let result = executor.execute(&agent_def, inputs).await?; + + println!("\n=== RESULT ===\n"); + println!("{}\n", result.output); + + println!("=== METADATA ==="); + println!("Duration: {}ms", result.metadata.duration_ms.unwrap_or(0)); + println!("Tokens used: {}", result.metadata.tokens.unwrap_or(0)); + println!( + "Model: {}", + result + .metadata + .model + .as_ref() + .unwrap_or(&"unknown".to_string()) + ); + + if result.validation_passed { + println!("Validation: ✓ PASSED"); + } else { + println!("Validation: ✗ FAILED"); + for error in &result.validation_errors { + println!(" - {}", error); + } + } + + println!("\n✓ Example completed successfully!"); + + Ok(()) +} diff --git a/crates/typedialog-agent/typedialog-ag-core/examples/provider_comparison.rs b/crates/typedialog-agent/typedialog-ag-core/examples/provider_comparison.rs new file mode 100644 index 0000000..f62502c --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/examples/provider_comparison.rs @@ -0,0 +1,277 @@ +//! Provider Comparison Demo +//! +//! Demonstrates all three LLM providers (Claude, OpenAI, Gemini) with the same prompt, +//! showing both blocking and streaming modes. +//! +//! Usage: +//! # Run all providers +//! cargo run --example provider_comparison +//! +//! # Run specific provider +//! cargo run --example provider_comparison claude +//! cargo run --example provider_comparison openai +//! cargo run --example provider_comparison gemini +//! +//! Environment variables required: +//! ANTHROPIC_API_KEY - for Claude +//! OPENAI_API_KEY - for OpenAI +//! GEMINI_API_KEY or GOOGLE_API_KEY - for Gemini + +use futures::stream::StreamExt; +use std::io::{self, Write}; +use std::time::Instant; +use typedialog_ag_core::llm::{ + ClaudeProvider, GeminiProvider, LlmMessage, LlmProvider, LlmRequest, MessageRole, + OllamaProvider, OpenAIProvider, StreamChunk, +}; + +const PROMPT: &str = + "Write a haiku about artificial intelligence. Return only the haiku, nothing else."; + +struct ProviderConfig { + name: &'static str, + model: &'static str, + color: &'static str, +} + +const CLAUDE_CONFIG: ProviderConfig = ProviderConfig { + name: "Claude", + model: "claude-3-5-haiku-20241022", + color: "\x1b[35m", // Magenta +}; + +const OPENAI_CONFIG: ProviderConfig = ProviderConfig { + name: "OpenAI", + model: "gpt-4o-mini", + color: "\x1b[32m", // Green +}; + +const GEMINI_CONFIG: ProviderConfig = ProviderConfig { + name: "Gemini", + model: "gemini-2.0-flash-exp", + color: "\x1b[34m", // Blue +}; + +const OLLAMA_CONFIG: ProviderConfig = ProviderConfig { + name: "Ollama", + model: "llama2", + color: "\x1b[33m", // Yellow +}; + +const RESET: &str = "\x1b[0m"; +const BOLD: &str = "\x1b[1m"; +const DIM: &str = "\x1b[2m"; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let args: Vec = std::env::args().collect(); + let provider_filter = args.get(1).map(|s| s.as_str()); + + println!("{}{}", BOLD, "═".repeat(70)); + println!("🤖 TypeAgent Multi-Provider Demo"); + println!("{}{}\n", "═".repeat(70), RESET); + + println!("{}Prompt:{} {}\n", BOLD, RESET, PROMPT); + + // Determine which providers to run + let run_claude = provider_filter.is_none() || provider_filter == Some("claude"); + let run_openai = provider_filter.is_none() || provider_filter == Some("openai"); + let run_gemini = provider_filter.is_none() || provider_filter == Some("gemini"); + let run_ollama = provider_filter.is_none() || provider_filter == Some("ollama"); + + // Run blocking mode comparison + println!("{}{}", BOLD, "━".repeat(70)); + println!("📋 BLOCKING MODE (complete() method)"); + println!("{}{}\n", "━".repeat(70), RESET); + + if run_claude { + run_blocking_demo(&CLAUDE_CONFIG, ClaudeProvider::new()).await; + } + + if run_openai { + run_blocking_demo(&OPENAI_CONFIG, OpenAIProvider::new()).await; + } + + if run_gemini { + run_blocking_demo(&GEMINI_CONFIG, GeminiProvider::new()).await; + } + + if run_ollama { + run_blocking_demo(&OLLAMA_CONFIG, OllamaProvider::new()).await; + } + + // Run streaming mode comparison + println!("\n{}{}", BOLD, "━".repeat(70)); + println!("⚡ STREAMING MODE (stream() method)"); + println!("{}{}\n", "━".repeat(70), RESET); + + if run_claude { + run_streaming_demo(&CLAUDE_CONFIG, ClaudeProvider::new()).await; + } + + if run_openai { + run_streaming_demo(&OPENAI_CONFIG, OpenAIProvider::new()).await; + } + + if run_gemini { + run_streaming_demo(&GEMINI_CONFIG, GeminiProvider::new()).await; + } + + if run_ollama { + run_streaming_demo(&OLLAMA_CONFIG, OllamaProvider::new()).await; + } + + // Summary + println!("\n{}{}", BOLD, "═".repeat(70)); + println!("📊 SUMMARY"); + println!("{}{}\n", "═".repeat(70), RESET); + + println!("{}All four providers support:{}", BOLD, RESET); + println!(" ✅ Blocking requests (complete())"); + println!(" ✅ Streaming responses (stream())"); + println!(" ✅ Token usage tracking"); + println!(" ✅ Error handling"); + println!(" ✅ System message support"); + println!("\n{}Provider-specific features:{}", BOLD, RESET); + println!(" 🟣 Claude: SSE streaming, usage in stream, cloud"); + println!(" 🟢 OpenAI: SSE streaming, no usage in stream, cloud"); + println!(" 🔵 Gemini: JSON streaming, usage in stream, cloud, 'model' role"); + println!(" 🟡 Ollama: JSON streaming, usage in stream, local models"); + + Ok(()) +} + +async fn run_blocking_demo( + config: &ProviderConfig, + provider_result: Result, +) { + println!( + "{}{}{} ({}){}", + config.color, BOLD, config.name, config.model, RESET + ); + println!( + "{}────────────────────────────────────────────────────────────────────{}", + DIM, RESET + ); + + let provider = match provider_result { + Ok(p) => p, + Err(e) => { + println!("{}❌ Error: {}{}\n", config.color, e, RESET); + return; + } + }; + + let request = LlmRequest { + model: config.model.to_string(), + messages: vec![LlmMessage { + role: MessageRole::User, + content: PROMPT.to_string(), + }], + max_tokens: Some(100), + temperature: Some(0.7), + system: Some("You are a creative poet.".to_string()), + }; + + let start = Instant::now(); + + match provider.complete(request).await { + Ok(response) => { + let duration = start.elapsed(); + + println!("{}Response:{}", BOLD, RESET); + println!("{}{}{}", config.color, response.content, RESET); + + println!("\n{}Metadata:{}", DIM, RESET); + println!(" Duration: {}ms", duration.as_millis()); + if let Some(usage) = response.usage { + println!( + " Tokens: {} (in: {}, out: {})", + usage.total_tokens, usage.input_tokens, usage.output_tokens + ); + } + println!(); + } + Err(e) => { + println!("{}❌ Error: {}{}\n", config.color, e, RESET); + } + } +} + +async fn run_streaming_demo( + config: &ProviderConfig, + provider_result: Result, +) { + println!( + "{}{}{} ({}){}", + config.color, BOLD, config.name, config.model, RESET + ); + println!( + "{}────────────────────────────────────────────────────────────────────{}", + DIM, RESET + ); + + let provider = match provider_result { + Ok(p) => p, + Err(e) => { + println!("{}❌ Error: {}{}\n", config.color, e, RESET); + return; + } + }; + + let request = LlmRequest { + model: config.model.to_string(), + messages: vec![LlmMessage { + role: MessageRole::User, + content: PROMPT.to_string(), + }], + max_tokens: Some(100), + temperature: Some(0.7), + system: Some("You are a creative poet.".to_string()), + }; + + let start = Instant::now(); + + match provider.stream(request).await { + Ok(mut stream) => { + print!("{}Response: {}{}", BOLD, RESET, config.color); + io::stdout().flush().unwrap(); + + let mut full_output = String::new(); + let mut tokens = None; + + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(StreamChunk::Content(text)) => { + print!("{}", text); + io::stdout().flush().unwrap(); + full_output.push_str(&text); + } + Ok(StreamChunk::Done(metadata)) => { + tokens = metadata.usage.map(|u| u.total_tokens); + } + Ok(StreamChunk::Error(err)) => { + println!("\n{}❌ Stream error: {}{}", config.color, err, RESET); + } + Err(e) => { + println!("\n{}❌ Error: {}{}", config.color, e, RESET); + } + } + } + + let duration = start.elapsed(); + + println!("{}", RESET); + println!("\n{}Metadata:{}", DIM, RESET); + println!(" Duration: {}ms", duration.as_millis()); + if let Some(t) = tokens { + println!(" Tokens: {}", t); + } + println!(" Characters streamed: {}", full_output.len()); + println!(); + } + Err(e) => { + println!("{}❌ Error: {}{}\n", config.color, e, RESET); + } + } +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/cache/mod.rs b/crates/typedialog-agent/typedialog-ag-core/src/cache/mod.rs new file mode 100644 index 0000000..4b17c92 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/cache/mod.rs @@ -0,0 +1,474 @@ +//! Cache layer for agent definitions and transpiled code +//! +//! Two-level cache: +//! - Memory: LRU cache for fast access +//! - Disk: Persistent cache in ~/.typeagent/cache/ + +use crate::error::{Error, Result}; +use lru::LruCache; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::fs; +use std::num::NonZeroUsize; +use std::path::{Path, PathBuf}; + +/// Cache strategy configuration +#[derive(Debug, Clone)] +pub enum CacheStrategy { + /// Memory-only cache (LRU) + Memory { + /// Max entries in LRU cache + max_entries: usize, + }, + /// Disk-only cache + Disk { + /// Cache directory + cache_dir: PathBuf, + }, + /// Both memory and disk cache + Both { + /// Max entries in memory + max_entries: usize, + /// Disk cache directory + cache_dir: PathBuf, + }, + /// No caching + None, +} + +impl Default for CacheStrategy { + fn default() -> Self { + Self::Both { + max_entries: 1000, + cache_dir: dirs::home_dir() + .unwrap_or_else(|| PathBuf::from(".")) + .join(".typeagent") + .join("cache"), + } + } +} + +/// Cached entry metadata +#[derive(Debug, Clone, Serialize, Deserialize)] +struct CacheEntry { + /// File path (for debugging) + file_path: String, + /// File modification time (Unix timestamp) + mtime: u64, + /// Cached Nickel code + nickel_code: String, +} + +/// Cache manager +pub struct CacheManager { + strategy: CacheStrategy, + /// In-memory LRU cache + memory_cache: Option>, +} + +impl CacheManager { + /// Create a new cache manager with strategy + pub fn new(strategy: CacheStrategy) -> Self { + let memory_cache = match &strategy { + CacheStrategy::Memory { max_entries } | CacheStrategy::Both { max_entries, .. } => { + NonZeroUsize::new(*max_entries).map(LruCache::new) + } + _ => None, + }; + + Self { + strategy, + memory_cache, + } + } + + /// Get cached transpiled Nickel code + /// + /// Returns Some(nickel_code) if cache hit and mtime matches, None otherwise + pub fn get_transpiled(&mut self, file_path: &str, mtime: u64) -> Option { + let cache_key = Self::compute_cache_key(file_path); + + match &self.strategy { + CacheStrategy::None => None, + + CacheStrategy::Memory { .. } => { + // Check memory cache only + self.memory_cache + .as_mut()? + .get(&cache_key) + .and_then(|entry| { + if entry.mtime == mtime { + Some(entry.nickel_code.clone()) + } else { + None + } + }) + } + + CacheStrategy::Disk { cache_dir } => { + // Check disk cache only + self.get_from_disk(cache_dir, &cache_key, mtime) + } + + CacheStrategy::Both { cache_dir, .. } => { + // Check memory first + if let Some(entry) = self.memory_cache.as_mut()?.get(&cache_key) { + if entry.mtime == mtime { + return Some(entry.nickel_code.clone()); + } + } + + // Fall back to disk + let nickel_code = self.get_from_disk(cache_dir, &cache_key, mtime)?; + + // Populate memory cache for next access + if let Some(ref mut mem_cache) = self.memory_cache { + mem_cache.put( + cache_key, + CacheEntry { + file_path: file_path.to_string(), + mtime, + nickel_code: nickel_code.clone(), + }, + ); + } + + Some(nickel_code) + } + } + } + + /// Cache transpiled Nickel code + pub fn put_transpiled(&mut self, file_path: &str, mtime: u64, nickel_code: &str) -> Result<()> { + let cache_key = Self::compute_cache_key(file_path); + let entry = CacheEntry { + file_path: file_path.to_string(), + mtime, + nickel_code: nickel_code.to_string(), + }; + + match &self.strategy { + CacheStrategy::None => Ok(()), + + CacheStrategy::Memory { .. } => { + // Store in memory only + if let Some(ref mut mem_cache) = self.memory_cache { + mem_cache.put(cache_key, entry); + } + Ok(()) + } + + CacheStrategy::Disk { cache_dir } => { + // Store on disk only + self.put_to_disk(cache_dir, &cache_key, &entry) + } + + CacheStrategy::Both { cache_dir, .. } => { + // Store in both + if let Some(ref mut mem_cache) = self.memory_cache { + mem_cache.put(cache_key.clone(), entry.clone()); + } + self.put_to_disk(cache_dir, &cache_key, &entry) + } + } + } + + /// Clear all caches + pub fn clear(&mut self) -> Result<()> { + // Clear memory + if let Some(ref mut mem_cache) = self.memory_cache { + mem_cache.clear(); + } + + // Clear disk + match &self.strategy { + CacheStrategy::Disk { cache_dir } | CacheStrategy::Both { cache_dir, .. } => { + if cache_dir.exists() { + fs::remove_dir_all(cache_dir).map_err(|e| { + Error::cache( + format!("Failed to remove cache directory: {:?}", cache_dir), + e.to_string(), + ) + })?; + } + } + _ => {} + } + + Ok(()) + } + + /// Get cache statistics + pub fn stats(&self) -> CacheStats { + let memory_entries = self.memory_cache.as_ref().map(|c| c.len()).unwrap_or(0); + + let disk_entries = match &self.strategy { + CacheStrategy::Disk { cache_dir } | CacheStrategy::Both { cache_dir, .. } => { + if cache_dir.exists() { + fs::read_dir(cache_dir) + .ok() + .map(|entries| entries.filter_map(|e| e.ok()).count()) + .unwrap_or(0) + } else { + 0 + } + } + _ => 0, + }; + + CacheStats { + memory_entries, + disk_entries, + strategy: format!("{:?}", self.strategy), + } + } + + /// Compute SHA256 hash of file path for cache key + fn compute_cache_key(file_path: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(file_path.as_bytes()); + hex::encode(hasher.finalize()) + } + + /// Get entry from disk cache + fn get_from_disk(&self, cache_dir: &Path, cache_key: &str, mtime: u64) -> Option { + let cache_file = cache_dir.join(cache_key); + + if !cache_file.exists() { + return None; + } + + // Read and deserialize + let content = fs::read_to_string(&cache_file).ok()?; + let entry: CacheEntry = serde_json::from_str(&content).ok()?; + + // Validate mtime + if entry.mtime != mtime { + // Stale cache - remove it (ignore errors) + fs::remove_file(&cache_file).ok(); + return None; + } + + Some(entry.nickel_code) + } + + /// Write entry to disk cache + fn put_to_disk(&self, cache_dir: &Path, cache_key: &str, entry: &CacheEntry) -> Result<()> { + // Ensure cache directory exists + if !cache_dir.exists() { + fs::create_dir_all(cache_dir).map_err(|e| { + Error::cache( + format!("Failed to create cache directory: {:?}", cache_dir), + e.to_string(), + ) + })?; + } + + let cache_file = cache_dir.join(cache_key); + + // Serialize and write + let content = serde_json::to_string_pretty(entry).map_err(|e| { + Error::cache( + format!("Failed to serialize cache entry for: {}", entry.file_path), + e.to_string(), + ) + })?; + + fs::write(&cache_file, content).map_err(|e| { + Error::cache( + format!("Failed to write cache file: {:?}", cache_file), + e.to_string(), + ) + })?; + + Ok(()) + } +} + +impl Default for CacheManager { + fn default() -> Self { + Self::new(CacheStrategy::default()) + } +} + +/// Cache statistics +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CacheStats { + /// Number of entries in memory cache + pub memory_entries: usize, + /// Number of entries in disk cache + pub disk_entries: usize, + /// Cache strategy description + pub strategy: String, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_compute_cache_key() { + let key1 = CacheManager::compute_cache_key("/path/to/agent.mdx"); + let key2 = CacheManager::compute_cache_key("/path/to/agent.mdx"); + let key3 = CacheManager::compute_cache_key("/different/path.mdx"); + + // Same path produces same key + assert_eq!(key1, key2); + // Different path produces different key + assert_ne!(key1, key3); + // Keys are hex-encoded SHA256 (64 chars) + assert_eq!(key1.len(), 64); + } + + #[test] + fn test_memory_cache_hit() { + let mut cache = CacheManager::new(CacheStrategy::Memory { max_entries: 10 }); + + let path = "/test/agent.mdx"; + let mtime = 12345u64; + let nickel = "{ config = { role = \"test\" } }"; + + // Put to cache + cache.put_transpiled(path, mtime, nickel).unwrap(); + + // Get from cache - should hit + let result = cache.get_transpiled(path, mtime); + assert_eq!(result, Some(nickel.to_string())); + } + + #[test] + fn test_memory_cache_miss_stale_mtime() { + let mut cache = CacheManager::new(CacheStrategy::Memory { max_entries: 10 }); + + let path = "/test/agent.mdx"; + let mtime = 12345u64; + let nickel = "{ config = { role = \"test\" } }"; + + // Put to cache + cache.put_transpiled(path, mtime, nickel).unwrap(); + + // Get with different mtime - should miss + let result = cache.get_transpiled(path, mtime + 1); + assert_eq!(result, None); + } + + #[test] + fn test_disk_cache_roundtrip() { + let temp_dir = std::env::temp_dir().join("typeagent-test-cache"); + let mut cache = CacheManager::new(CacheStrategy::Disk { + cache_dir: temp_dir.clone(), + }); + + let path = "/test/agent.mdx"; + let mtime = 12345u64; + let nickel = "{ config = { role = \"test\" } }"; + + // Put to cache + cache.put_transpiled(path, mtime, nickel).unwrap(); + + // Get from cache - should hit + let result = cache.get_transpiled(path, mtime); + assert_eq!(result, Some(nickel.to_string())); + + // Cleanup + cache.clear().unwrap(); + } + + #[test] + fn test_both_cache_strategy() { + let temp_dir = std::env::temp_dir().join("typeagent-test-cache-both"); + let mut cache = CacheManager::new(CacheStrategy::Both { + max_entries: 10, + cache_dir: temp_dir.clone(), + }); + + let path = "/test/agent.mdx"; + let mtime = 12345u64; + let nickel = "{ config = { role = \"test\" } }"; + + // Put to cache (should go to both memory and disk) + cache.put_transpiled(path, mtime, nickel).unwrap(); + + // Clear memory cache + if let Some(ref mut mem_cache) = cache.memory_cache { + mem_cache.clear(); + } + + // Get from cache - should hit from disk and repopulate memory + let result = cache.get_transpiled(path, mtime); + assert_eq!(result, Some(nickel.to_string())); + + // Memory should now have the entry + assert_eq!(cache.memory_cache.as_ref().unwrap().len(), 1); + + // Cleanup + cache.clear().unwrap(); + } + + #[test] + fn test_cache_stats() { + let temp_dir = std::env::temp_dir().join("typeagent-test-cache-stats"); + let mut cache = CacheManager::new(CacheStrategy::Both { + max_entries: 10, + cache_dir: temp_dir.clone(), + }); + + // Initially empty + let stats = cache.stats(); + assert_eq!(stats.memory_entries, 0); + assert_eq!(stats.disk_entries, 0); + + // Add entries + cache.put_transpiled("/test/1.mdx", 111, "code1").unwrap(); + cache.put_transpiled("/test/2.mdx", 222, "code2").unwrap(); + + let stats = cache.stats(); + assert_eq!(stats.memory_entries, 2); + assert_eq!(stats.disk_entries, 2); + + // Cleanup + cache.clear().unwrap(); + } + + #[test] + fn test_clear_cache() { + let temp_dir = std::env::temp_dir().join("typeagent-test-cache-clear"); + let mut cache = CacheManager::new(CacheStrategy::Both { + max_entries: 10, + cache_dir: temp_dir.clone(), + }); + + // Add entries + cache.put_transpiled("/test/1.mdx", 111, "code1").unwrap(); + cache.put_transpiled("/test/2.mdx", 222, "code2").unwrap(); + + // Verify they exist + let stats = cache.stats(); + assert_eq!(stats.memory_entries, 2); + assert_eq!(stats.disk_entries, 2); + + // Clear + cache.clear().unwrap(); + + // Verify empty + let stats = cache.stats(); + assert_eq!(stats.memory_entries, 0); + assert_eq!(stats.disk_entries, 0); + } + + #[test] + fn test_none_strategy() { + let mut cache = CacheManager::new(CacheStrategy::None); + + let path = "/test/agent.mdx"; + let mtime = 12345u64; + let nickel = "{ config = { role = \"test\" } }"; + + // Put should succeed but do nothing + cache.put_transpiled(path, mtime, nickel).unwrap(); + + // Get should always return None + let result = cache.get_transpiled(path, mtime); + assert_eq!(result, None); + } +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/error.rs b/crates/typedialog-agent/typedialog-ag-core/src/error.rs new file mode 100644 index 0000000..9eb5f60 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/error.rs @@ -0,0 +1,260 @@ +//! Error types for TypeAgent +//! +//! Following M-ERRORS-CANONICAL-STRUCTS: situation-specific error struct with kind enum + +use std::error::Error as StdError; +use std::fmt; + +pub type Result = std::result::Result; + +/// TypeAgent error with context and source chaining +#[derive(Debug)] +pub struct Error { + pub kind: ErrorKind, + pub context: String, + pub source: Option>, +} + +/// Error kind taxonomy +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ErrorKind { + /// MDX/Markdown parsing failed + Parse { + line: Option, + column: Option, + }, + /// AST to Nickel transpilation failed + Transpile { detail: String }, + /// Nickel type checking or evaluation failed + NickelEval { detail: String }, + /// Agent execution failed + Execution { stage: String }, + /// I/O operation failed + Io, + /// JSON/YAML serialization failed + Serialization, + /// Cache operation failed + Cache { operation: String }, + /// Output validation failed + Validation { rule: String }, + /// File format not supported + UnsupportedFormat { extension: String }, + /// File not found + NotFound { path: String }, +} + +impl Error { + /// Create parse error with location + pub fn parse(message: impl Into, line: Option, column: Option) -> Self { + Self { + kind: ErrorKind::Parse { line, column }, + context: message.into(), + source: None, + } + } + + /// Create transpile error + pub fn transpile(message: impl Into, detail: impl Into) -> Self { + Self { + kind: ErrorKind::Transpile { + detail: detail.into(), + }, + context: message.into(), + source: None, + } + } + + /// Create Nickel evaluation error + pub fn nickel_eval(message: impl Into, detail: impl Into) -> Self { + Self { + kind: ErrorKind::NickelEval { + detail: detail.into(), + }, + context: message.into(), + source: None, + } + } + + /// Create execution error + pub fn execution(message: impl Into, stage: impl Into) -> Self { + Self { + kind: ErrorKind::Execution { + stage: stage.into(), + }, + context: message.into(), + source: None, + } + } + + /// Create I/O error + pub fn io(message: impl Into, detail: impl Into) -> Self { + Self { + kind: ErrorKind::Io, + context: format!("{}: {}", message.into(), detail.into()), + source: None, + } + } + + /// Create cache error + pub fn cache(message: impl Into, operation: impl Into) -> Self { + Self { + kind: ErrorKind::Cache { + operation: operation.into(), + }, + context: message.into(), + source: None, + } + } + + /// Create validation error + pub fn validation(message: impl Into, rule: impl Into) -> Self { + Self { + kind: ErrorKind::Validation { rule: rule.into() }, + context: message.into(), + source: None, + } + } + + /// Create unsupported format error + pub fn unsupported_format(extension: impl Into) -> Self { + Self { + kind: ErrorKind::UnsupportedFormat { + extension: extension.into(), + }, + context: "Unsupported agent file format".to_string(), + source: None, + } + } + + /// Create file not found error + pub fn not_found(path: impl Into) -> Self { + Self { + kind: ErrorKind::NotFound { path: path.into() }, + context: "Agent file not found".to_string(), + source: None, + } + } + + /// Add source error for chaining + pub fn with_source(mut self, source: impl StdError + Send + Sync + 'static) -> Self { + self.source = Some(Box::new(source)); + self + } + + /// Check if error is parse error + pub fn is_parse(&self) -> bool { + matches!(self.kind, ErrorKind::Parse { .. }) + } + + /// Check if error is transpile error + pub fn is_transpile(&self) -> bool { + matches!(self.kind, ErrorKind::Transpile { .. }) + } + + /// Check if error is Nickel evaluation error + pub fn is_nickel_eval(&self) -> bool { + matches!(self.kind, ErrorKind::NickelEval { .. }) + } + + /// Check if error is execution error + pub fn is_execution(&self) -> bool { + matches!(self.kind, ErrorKind::Execution { .. }) + } + + /// Check if error is validation error + pub fn is_validation(&self) -> bool { + matches!(self.kind, ErrorKind::Validation { .. }) + } + + /// Get parse location if available + pub fn parse_location(&self) -> Option<(Option, Option)> { + if let ErrorKind::Parse { line, column } = &self.kind { + Some((*line, *column)) + } else { + None + } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.kind { + ErrorKind::Parse { line, column } => { + write!(f, "Parse error")?; + if let Some(l) = line { + write!(f, " at line {}", l)?; + } + if let Some(c) = column { + write!(f, ", column {}", c)?; + } + write!(f, ": {}", self.context) + } + ErrorKind::Transpile { detail } => { + write!(f, "Transpile error ({}): {}", detail, self.context) + } + ErrorKind::NickelEval { detail } => { + write!(f, "Nickel evaluation error ({}): {}", detail, self.context) + } + ErrorKind::Execution { stage } => { + write!(f, "Execution error at {}: {}", stage, self.context) + } + ErrorKind::Io => { + write!(f, "I/O error: {}", self.context) + } + ErrorKind::Serialization => { + write!(f, "Serialization error: {}", self.context) + } + ErrorKind::Cache { operation } => { + write!(f, "Cache error during {}: {}", operation, self.context) + } + ErrorKind::Validation { rule } => { + write!(f, "Validation failed ({}): {}", rule, self.context) + } + ErrorKind::UnsupportedFormat { extension } => { + write!(f, "Unsupported format '{}': {}", extension, self.context) + } + ErrorKind::NotFound { path } => { + write!(f, "File not found '{}': {}", path, self.context) + } + } + } +} + +impl StdError for Error { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.source.as_ref().map(|e| e.as_ref() as &dyn StdError) + } +} + +/// Conversion from std::io::Error +impl From for Error { + fn from(err: std::io::Error) -> Self { + Self { + kind: ErrorKind::Io, + context: err.to_string(), + source: Some(Box::new(err)), + } + } +} + +/// Conversion from serde_json::Error +impl From for Error { + fn from(err: serde_json::Error) -> Self { + Self { + kind: ErrorKind::Serialization, + context: err.to_string(), + source: Some(Box::new(err)), + } + } +} + +/// Conversion from serde_yaml::Error +impl From for Error { + fn from(err: serde_yaml::Error) -> Self { + Self { + kind: ErrorKind::Serialization, + context: err.to_string(), + source: Some(Box::new(err)), + } + } +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/executor/mod.rs b/crates/typedialog-agent/typedialog-ag-core/src/executor/mod.rs new file mode 100644 index 0000000..32d9ad7 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/executor/mod.rs @@ -0,0 +1,829 @@ +//! Agent execution layer +//! +//! Executes agents with context injection and output validation + +use crate::error::{Error, Result}; +use crate::nickel::AgentDefinition; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::process::Command; +use tera::{Context, Tera}; + +/// Agent execution result +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExecutionResult { + /// Generated output from agent + pub output: String, + /// Whether validation passed + pub validation_passed: bool, + /// Validation errors if any + #[serde(default)] + pub validation_errors: Vec, + /// Execution metadata + #[serde(default)] + pub metadata: ExecutionMetadata, +} + +/// Execution metadata +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct ExecutionMetadata { + /// Time taken in milliseconds + pub duration_ms: Option, + /// Token count if available + pub tokens: Option, + /// Model used + pub model: Option, +} + +/// Context data assembled from imports and shell commands +#[derive(Debug, Clone)] +pub struct ContextData { + /// File imports content (alias → content) + pub imports: HashMap, + /// Shell command outputs (alias → output) + pub shell_outputs: HashMap, +} + +/// Agent executor +pub struct AgentExecutor; + +impl AgentExecutor { + /// Create a new agent executor + pub fn new() -> Self { + Self + } + + /// Execute an agent with given inputs + pub async fn execute( + &self, + agent: &AgentDefinition, + inputs: HashMap, + ) -> Result { + let start = std::time::Instant::now(); + + // Load context data + let context_data = self.load_context(agent).await?; + + // Render template with inputs and context + let rendered_prompt = self.render_template(agent, &inputs, &context_data)?; + + // Execute LLM with rendered prompt + let (output, tokens) = self.execute_llm(agent, &rendered_prompt).await?; + + // Validate output + let validation_errors = self.validate_output(&output, agent)?; + let validation_passed = validation_errors.is_empty(); + + let duration_ms = start.elapsed().as_millis() as u64; + + Ok(ExecutionResult { + output, + validation_passed, + validation_errors, + metadata: ExecutionMetadata { + duration_ms: Some(duration_ms), + tokens, + model: Some(agent.config.llm.clone()), + }, + }) + } + + /// Execute an agent with streaming output + /// + /// The callback is called for each chunk of output as it arrives. + /// Returns the final execution result when the stream completes. + pub async fn execute_streaming( + &self, + agent: &AgentDefinition, + inputs: HashMap, + mut on_chunk: F, + ) -> Result + where + F: FnMut(&str), + { + use crate::llm::{create_provider, LlmMessage, LlmRequest, MessageRole, StreamChunk}; + use futures::stream::StreamExt; + + let start = std::time::Instant::now(); + + // Load context data + let context_data = self.load_context(agent).await?; + + // Render template with inputs and context + let rendered_prompt = self.render_template(agent, &inputs, &context_data)?; + + // Create provider and build request + let provider = create_provider(&agent.config.llm)?; + let request = LlmRequest { + model: agent.config.llm.clone(), + messages: vec![LlmMessage { + role: MessageRole::User, + content: rendered_prompt.clone(), + }], + max_tokens: Some(agent.config.max_tokens), + temperature: Some(agent.config.temperature), + system: Some(format!("You are a {}.", agent.config.role)), + }; + + // Stream the response + let mut stream = provider.stream(request).await?; + let mut full_output = String::new(); + let mut tokens = None; + + while let Some(chunk_result) = stream.next().await { + match chunk_result? { + StreamChunk::Content(text) => { + full_output.push_str(&text); + on_chunk(&text); + } + StreamChunk::Done(metadata) => { + tokens = metadata.usage.map(|u| u.total_tokens); + } + StreamChunk::Error(err) => { + return Err(Error::execution("Streaming error", err)); + } + } + } + + // Validate output + let validation_errors = self.validate_output(&full_output, agent)?; + let validation_passed = validation_errors.is_empty(); + + let duration_ms = start.elapsed().as_millis() as u64; + + Ok(ExecutionResult { + output: full_output, + validation_passed, + validation_errors, + metadata: ExecutionMetadata { + duration_ms: Some(duration_ms), + tokens, + model: Some(agent.config.llm.clone()), + }, + }) + } + + /// Execute LLM with the rendered prompt + async fn execute_llm( + &self, + agent: &AgentDefinition, + prompt: &str, + ) -> Result<(String, Option)> { + use crate::llm::{create_provider, LlmMessage, LlmRequest, MessageRole}; + + // Create provider based on model name + let provider = create_provider(&agent.config.llm)?; + + // Build request + let request = LlmRequest { + model: agent.config.llm.clone(), + messages: vec![LlmMessage { + role: MessageRole::User, + content: prompt.to_string(), + }], + max_tokens: Some(agent.config.max_tokens), + temperature: Some(agent.config.temperature), + system: Some(format!("You are a {}.", agent.config.role)), + }; + + // Execute request + let response = provider.complete(request).await?; + + // Extract token count + let tokens = response.usage.as_ref().map(|u| u.total_tokens); + + Ok((response.content, tokens)) + } + + /// Load context from imports and shell commands + async fn load_context(&self, agent: &AgentDefinition) -> Result { + let mut imports = HashMap::new(); + let mut shell_outputs = HashMap::new(); + + if let Some(context) = &agent.context { + // Load file imports + for import_path in &context.imports { + let content = self.load_import(import_path).await?; + imports.insert(import_path.clone(), content); + } + + // Execute shell commands + for shell_cmd in &context.shell_commands { + let output = self.execute_shell_command(shell_cmd)?; + shell_outputs.insert(shell_cmd.clone(), output); + } + } + + Ok(ContextData { + imports, + shell_outputs, + }) + } + + /// Load content from import path (file, glob, or URL) + async fn load_import(&self, path: &str) -> Result { + // Check if it's a URL + if path.starts_with("http://") || path.starts_with("https://") { + return self.load_url(path).await; + } + + // Check if it's a glob pattern + if path.contains('*') { + return self.load_glob(path); + } + + // Regular file path + std::fs::read_to_string(path).map_err(|e| { + Error::execution( + format!("Failed to read import file: {}", path), + e.to_string(), + ) + }) + } + + /// Load content from URL + async fn load_url(&self, url: &str) -> Result { + let response = reqwest::get(url).await.map_err(|e| { + Error::execution(format!("Failed to fetch URL: {}", url), e.to_string()) + })?; + + response.text().await.map_err(|e| { + Error::execution( + format!("Failed to read response from URL: {}", url), + e.to_string(), + ) + }) + } + + /// Load files matching glob pattern + fn load_glob(&self, pattern: &str) -> Result { + use globset::GlobBuilder; + use ignore::WalkBuilder; + + let glob = GlobBuilder::new(pattern) + .build() + .map_err(|e| { + Error::execution(format!("Invalid glob pattern: {}", pattern), e.to_string()) + })? + .compile_matcher(); + + let mut contents = Vec::new(); + + // Extract base path from pattern + let base_path = pattern.split("**").next().unwrap_or("."); + + for entry in WalkBuilder::new(base_path).build() { + let entry = entry.map_err(|e| { + Error::execution( + format!("Failed to walk directory: {}", base_path), + e.to_string(), + ) + })?; + + if entry.file_type().map(|ft| ft.is_file()).unwrap_or(false) + && glob.is_match(entry.path()) + { + let content = std::fs::read_to_string(entry.path()).map_err(|e| { + Error::execution( + format!("Failed to read file: {:?}", entry.path()), + e.to_string(), + ) + })?; + + contents.push(format!("// File: {}\n{}", entry.path().display(), content)); + } + } + + if contents.is_empty() { + return Err(Error::execution( + format!("No files matched glob pattern: {}", pattern), + "glob returned empty", + )); + } + + Ok(contents.join("\n\n")) + } + + /// Execute shell command and capture output + fn execute_shell_command(&self, command: &str) -> Result { + let output = Command::new("sh") + .arg("-c") + .arg(command) + .output() + .map_err(|e| { + Error::execution( + format!("Failed to execute shell command: {}", command), + e.to_string(), + ) + })?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(Error::execution( + format!("Shell command failed: {}", command), + stderr.to_string(), + )); + } + + Ok(String::from_utf8_lossy(&output.stdout).to_string()) + } + + /// Render template with inputs and context + fn render_template( + &self, + agent: &AgentDefinition, + inputs: &HashMap, + context_data: &ContextData, + ) -> Result { + let mut tera = Tera::default(); + + // Add template + tera.add_raw_template("agent", &agent.template) + .map_err(|e| Error::execution("Failed to parse template", e.to_string()))?; + + // Build context + let mut ctx = Context::new(); + + // Add inputs + for (key, value) in inputs { + ctx.insert(key, value); + } + + // Add agent config + ctx.insert("role", &agent.config.role); + ctx.insert("llm", &agent.config.llm); + ctx.insert("tools", &agent.config.tools); + + // Add imports + for (alias, content) in &context_data.imports { + ctx.insert(alias, content); + } + + // Add shell outputs + for (alias, output) in &context_data.shell_outputs { + ctx.insert(alias, output); + } + + // Render + tera.render("agent", &ctx) + .map_err(|e| Error::execution("Failed to render template", e.to_string())) + } + + /// Validate output against agent validation rules + pub fn validate_output(&self, output: &str, agent: &AgentDefinition) -> Result> { + let mut errors = Vec::new(); + + if let Some(validation) = &agent.validation { + // Check must_contain + for pattern in &validation.must_contain { + if !output.contains(pattern) { + errors.push(format!("Output must contain: {}", pattern)); + } + } + + // Check must_not_contain + for pattern in &validation.must_not_contain { + if output.contains(pattern) { + errors.push(format!("Output must not contain: {}", pattern)); + } + } + + // Check min_length + if let Some(min_len) = validation.min_length { + if output.len() < min_len { + errors.push(format!( + "Output too short: {} chars (minimum: {})", + output.len(), + min_len + )); + } + } + + // Check max_length + if let Some(max_len) = validation.max_length { + if output.len() > max_len { + errors.push(format!( + "Output too long: {} chars (maximum: {})", + output.len(), + max_len + )); + } + } + + // Check format + match validation.format.as_str() { + "json" => { + if serde_json::from_str::(output).is_err() { + errors.push("Output is not valid JSON".to_string()); + } + } + "yaml" => { + if serde_yaml::from_str::(output).is_err() { + errors.push("Output is not valid YAML".to_string()); + } + } + _ => { + // Markdown, text, and other formats are permissive - any text is valid + } + } + } + + Ok(errors) + } +} + +impl Default for AgentExecutor { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::nickel::{AgentConfig, ValidationRules}; + + #[test] + fn test_validate_output_must_contain() { + let executor = AgentExecutor::new(); + let agent = AgentDefinition { + config: AgentConfig { + role: "tester".to_string(), + llm: "claude".to_string(), + tools: vec![], + max_tokens: 4096, + temperature: 0.7, + }, + inputs: HashMap::new(), + context: None, + validation: Some(ValidationRules { + must_contain: vec!["test".to_string(), "pass".to_string()], + must_not_contain: vec![], + format: "markdown".to_string(), + min_length: None, + max_length: None, + }), + template: "".to_string(), + }; + + let output = "This is a test and it should pass"; + let errors = executor.validate_output(output, &agent).unwrap(); + assert!(errors.is_empty()); + + let output_bad = "This is missing something"; + let errors = executor.validate_output(output_bad, &agent).unwrap(); + assert!(!errors.is_empty()); + assert!(errors[0].contains("must contain")); + } + + #[test] + fn test_validate_output_length() { + let executor = AgentExecutor::new(); + let agent = AgentDefinition { + config: AgentConfig { + role: "tester".to_string(), + llm: "claude".to_string(), + tools: vec![], + max_tokens: 4096, + temperature: 0.7, + }, + inputs: HashMap::new(), + context: None, + validation: Some(ValidationRules { + must_contain: vec![], + must_not_contain: vec![], + format: "markdown".to_string(), + min_length: Some(10), + max_length: Some(50), + }), + template: "".to_string(), + }; + + let output = "Just right length"; + let errors = executor.validate_output(output, &agent).unwrap(); + assert!(errors.is_empty()); + + let output_short = "Short"; + let errors = executor.validate_output(output_short, &agent).unwrap(); + assert!(!errors.is_empty()); + assert!(errors[0].contains("too short")); + + let output_long = "This is way too long for the maximum length constraint that we've set"; + let errors = executor.validate_output(output_long, &agent).unwrap(); + assert!(!errors.is_empty()); + assert!(errors[0].contains("too long")); + } + + #[test] + fn test_validate_output_json_format() { + let executor = AgentExecutor::new(); + let agent = AgentDefinition { + config: AgentConfig { + role: "tester".to_string(), + llm: "claude".to_string(), + tools: vec![], + max_tokens: 4096, + temperature: 0.7, + }, + inputs: HashMap::new(), + context: None, + validation: Some(ValidationRules { + must_contain: vec![], + must_not_contain: vec![], + format: "json".to_string(), + min_length: None, + max_length: None, + }), + template: "".to_string(), + }; + + let output_valid = r#"{"key": "value"}"#; + let errors = executor.validate_output(output_valid, &agent).unwrap(); + assert!(errors.is_empty()); + + let output_invalid = "not json at all"; + let errors = executor.validate_output(output_invalid, &agent).unwrap(); + assert!(!errors.is_empty()); + assert!(errors[0].contains("not valid JSON")); + } + + #[test] + fn test_render_template_basic() { + let executor = AgentExecutor::new(); + let agent = AgentDefinition { + config: AgentConfig { + role: "architect".to_string(), + llm: "claude-opus-4".to_string(), + tools: vec![], + max_tokens: 4096, + temperature: 0.7, + }, + inputs: HashMap::new(), + context: None, + validation: None, + template: "Design feature: {{ feature_name }}".to_string(), + }; + + let mut inputs = HashMap::new(); + inputs.insert("feature_name".to_string(), serde_json::json!("auth")); + + let context_data = ContextData { + imports: HashMap::new(), + shell_outputs: HashMap::new(), + }; + + let result = executor.render_template(&agent, &inputs, &context_data); + assert!(result.is_ok()); + let rendered = result.unwrap(); + assert_eq!(rendered, "Design feature: auth"); + } + + #[tokio::test] + async fn test_execute_shell_command() { + let executor = AgentExecutor::new(); + let result = executor.execute_shell_command("echo hello"); + assert!(result.is_ok()); + assert_eq!(result.unwrap().trim(), "hello"); + } + + #[tokio::test] + #[ignore] // Requires ANTHROPIC_API_KEY + async fn test_execute_with_real_llm() { + let executor = AgentExecutor::new(); + let agent = AgentDefinition { + config: AgentConfig { + role: "assistant".to_string(), + llm: "claude-3-5-haiku-20241022".to_string(), + tools: vec![], + max_tokens: 100, + temperature: 0.7, + }, + inputs: HashMap::new(), + context: None, + validation: None, + template: "Say hello to {{ name }} in exactly 3 words.".to_string(), + }; + + let mut inputs = HashMap::new(); + inputs.insert("name".to_string(), serde_json::json!("Alice")); + + let result = executor.execute(&agent, inputs).await; + assert!(result.is_ok()); + + let exec_result = result.unwrap(); + assert!(!exec_result.output.is_empty()); + assert!(exec_result.metadata.tokens.is_some()); + + println!("LLM Response: {}", exec_result.output); + println!("Tokens used: {:?}", exec_result.metadata.tokens); + } + + // Mock provider for testing streaming without API calls + mod mock { + use super::*; + use crate::llm::{ + LlmProvider, LlmRequest, LlmResponse, LlmStream, StreamChunk, StreamMetadata, + TokenUsage, + }; + use async_trait::async_trait; + use futures::stream; + + pub struct MockStreamingProvider { + chunks: Vec, + should_error: bool, + } + + impl MockStreamingProvider { + pub fn new(chunks: Vec) -> Self { + Self { + chunks, + should_error: false, + } + } + + pub fn with_error() -> Self { + Self { + chunks: vec![], + should_error: true, + } + } + } + + #[async_trait] + impl LlmProvider for MockStreamingProvider { + async fn complete(&self, _request: LlmRequest) -> Result { + Ok(LlmResponse { + content: self.chunks.join(""), + model: "mock-model".to_string(), + usage: Some(TokenUsage { + input_tokens: 10, + output_tokens: 20, + total_tokens: 30, + }), + }) + } + + async fn stream(&self, _request: LlmRequest) -> Result { + if self.should_error { + let error_stream = + stream::iter(vec![Ok(StreamChunk::Error("Mock error".to_string()))]); + return Ok(Box::pin(error_stream)); + } + + let mut events: Vec> = self + .chunks + .iter() + .map(|chunk| Ok(StreamChunk::Content(chunk.clone()))) + .collect(); + + // Add final Done event with usage + events.push(Ok(StreamChunk::Done(StreamMetadata { + model: "mock-model".to_string(), + usage: Some(TokenUsage { + input_tokens: 10, + output_tokens: 20, + total_tokens: 30, + }), + }))); + + Ok(Box::pin(stream::iter(events))) + } + + fn name(&self) -> &str { + "MockStreaming" + } + } + } + + #[tokio::test] + async fn test_streaming_chunks_accumulate() { + use crate::llm::LlmProvider; + use std::sync::{Arc, Mutex}; + + // Track chunks as they arrive + let received_chunks = Arc::new(Mutex::new(Vec::new())); + let chunks_clone = Arc::clone(&received_chunks); + + // Mock provider that streams "Hello", " ", "world" + let mock_provider = mock::MockStreamingProvider::new(vec![ + "Hello".to_string(), + " ".to_string(), + "world".to_string(), + ]); + + // Manually test streaming with mock provider + use crate::llm::{LlmMessage, LlmRequest, MessageRole, StreamChunk}; + use futures::stream::StreamExt; + + let request = LlmRequest { + model: "mock-model".to_string(), + messages: vec![LlmMessage { + role: MessageRole::User, + content: "Test".to_string(), + }], + max_tokens: Some(100), + temperature: Some(0.7), + system: Some("Test system".to_string()), + }; + + let mut stream = mock_provider.stream(request).await.unwrap(); + let mut full_output = String::new(); + let mut tokens = None; + + while let Some(chunk_result) = stream.next().await { + match chunk_result.unwrap() { + StreamChunk::Content(text) => { + chunks_clone.lock().unwrap().push(text.clone()); + full_output.push_str(&text); + } + StreamChunk::Done(metadata) => { + tokens = metadata.usage.map(|u| u.total_tokens); + } + StreamChunk::Error(_) => panic!("Unexpected error"), + } + } + + // Verify chunks were received in order + let chunks = received_chunks.lock().unwrap(); + assert_eq!(chunks.len(), 3); + assert_eq!(chunks[0], "Hello"); + assert_eq!(chunks[1], " "); + assert_eq!(chunks[2], "world"); + + // Verify full output + assert_eq!(full_output, "Hello world"); + + // Verify tokens + assert_eq!(tokens, Some(30)); + } + + #[tokio::test] + async fn test_streaming_error_handling() { + use crate::llm::{LlmMessage, LlmProvider, LlmRequest, MessageRole, StreamChunk}; + use futures::stream::StreamExt; + + let mock_provider = mock::MockStreamingProvider::with_error(); + + let request = LlmRequest { + model: "mock-model".to_string(), + messages: vec![LlmMessage { + role: MessageRole::User, + content: "Test".to_string(), + }], + max_tokens: Some(100), + temperature: Some(0.7), + system: Some("Test system".to_string()), + }; + + let mut stream = mock_provider.stream(request).await.unwrap(); + let mut error_received = false; + + while let Some(chunk_result) = stream.next().await { + if let StreamChunk::Error(msg) = chunk_result.unwrap() { + assert_eq!(msg, "Mock error"); + error_received = true; + } + } + + assert!(error_received); + } + + #[tokio::test] + async fn test_streaming_empty_chunks() { + use crate::llm::{LlmMessage, LlmProvider, LlmRequest, MessageRole, StreamChunk}; + use futures::stream::StreamExt; + + // Provider with no content chunks, just Done + let mock_provider = mock::MockStreamingProvider::new(vec![]); + + let request = LlmRequest { + model: "mock-model".to_string(), + messages: vec![LlmMessage { + role: MessageRole::User, + content: "Test".to_string(), + }], + max_tokens: Some(100), + temperature: Some(0.7), + system: Some("Test system".to_string()), + }; + + let mut stream = mock_provider.stream(request).await.unwrap(); + let mut full_output = String::new(); + let mut done_received = false; + + while let Some(chunk_result) = stream.next().await { + match chunk_result.unwrap() { + StreamChunk::Content(text) => { + full_output.push_str(&text); + } + StreamChunk::Done(_) => { + done_received = true; + } + StreamChunk::Error(_) => panic!("Unexpected error"), + } + } + + assert_eq!(full_output, ""); + assert!(done_received); + } +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/formats/mod.rs b/crates/typedialog-agent/typedialog-ag-core/src/formats/mod.rs new file mode 100644 index 0000000..efe356b --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/formats/mod.rs @@ -0,0 +1,39 @@ +//! Format detection for agent files +//! +//! Supports three agent file formats: +//! - `.agent.mdx`: Markdown with @directives (primary format) +//! - `.agent.ncl`: Pure Nickel configuration +//! - `.agent.md`: Legacy markdown with YAML frontmatter + +use std::path::Path; + +/// Agent file format +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AgentFormat { + /// Markdown with @directives (.agent.mdx) + MarkdownExtended, + /// Pure Nickel (.agent.ncl) + Nickel, + /// Legacy markdown with YAML frontmatter (.agent.md) + Markdown, +} + +/// Format detector based on file extension +pub struct FormatDetector; + +impl FormatDetector { + /// Detect agent format from file path extension + pub fn detect(path: &Path) -> Option { + let filename = path.file_name()?.to_str()?; + + if filename.ends_with(".agent.mdx") { + Some(AgentFormat::MarkdownExtended) + } else if filename.ends_with(".agent.ncl") { + Some(AgentFormat::Nickel) + } else if filename.ends_with(".agent.md") { + Some(AgentFormat::Markdown) + } else { + None + } + } +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/lib.rs b/crates/typedialog-agent/typedialog-ag-core/src/lib.rs new file mode 100644 index 0000000..f5f320a --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/lib.rs @@ -0,0 +1,183 @@ +#![allow(clippy::module_name_repetitions)] +#![allow(clippy::too_many_arguments)] + +//! TypeAgent Core Library +//! +//! Type-safe AI agent execution with 3-layer pipeline: +//! - Layer 1: MDX → AST (markup parsing) +//! - Layer 2: AST → Nickel (transpilation + type checking) +//! - Layer 3: Nickel → Output (execution + validation) + +pub mod cache; +pub mod error; +pub mod executor; +pub mod formats; +pub mod llm; +pub mod nickel; +pub mod parser; +pub mod transpiler; +pub mod utils; + +// Public API exports +pub use cache::{CacheManager, CacheStats, CacheStrategy}; +pub use error::{Error, Result}; +pub use executor::{AgentExecutor, ExecutionResult}; +pub use formats::{AgentFormat, FormatDetector}; +pub use nickel::{AgentConfig, AgentDefinition, NickelEvaluator}; +pub use parser::{AgentDirective, MarkupNode, MarkupParser}; +pub use transpiler::NickelTranspiler; + +/// Agent loader - main entry point +pub struct AgentLoader { + parser: MarkupParser, + transpiler: NickelTranspiler, + evaluator: NickelEvaluator, + executor: AgentExecutor, + cache: Option>>, +} + +impl AgentLoader { + /// Create new agent loader + pub fn new() -> Self { + Self { + parser: MarkupParser::new(), + transpiler: NickelTranspiler::new(), + evaluator: NickelEvaluator::new(), + executor: AgentExecutor::new(), + cache: Some(std::sync::Arc::new(std::sync::Mutex::new( + CacheManager::default(), + ))), + } + } + + /// Create without cache + pub fn without_cache() -> Self { + Self { + parser: MarkupParser::new(), + transpiler: NickelTranspiler::new(), + evaluator: NickelEvaluator::new(), + executor: AgentExecutor::new(), + cache: None, + } + } + + /// Load agent from file + /// + /// Executes the 3-layer pipeline: + /// 1. Parse MDX → AST + /// 2. Transpile AST → Nickel + /// 3. Evaluate Nickel → AgentDefinition + pub async fn load(&self, path: &std::path::Path) -> Result { + // Read file content + let content = std::fs::read_to_string(path).map_err(|e| { + Error::io( + format!("Failed to read agent file: {:?}", path), + e.to_string(), + ) + })?; + + // Check cache for transpiled Nickel code + let file_mtime = std::fs::metadata(path) + .ok() + .and_then(|m| m.modified().ok()) + .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok()) + .map(|d| d.as_secs()) + .unwrap_or(0); + + let path_str = path.to_string_lossy().to_string(); + + // Try to get from cache + let nickel_code = if let Some(cache_arc) = &self.cache { + // Try to get from cache first + let mut cache = cache_arc.lock().unwrap(); + if let Some(cached) = cache.get_transpiled(&path_str, file_mtime) { + cached + } else { + // Not in cache, do full transpilation + drop(cache); // Release lock before parsing + + let ast = self.parser.parse(&content)?; + let nickel = self.transpiler.transpile(&ast)?; + + // Store in cache + let mut cache_mut = cache_arc.lock().unwrap(); + // Ignore cache errors - we still have the nickel code to use + cache_mut + .put_transpiled(&path_str, file_mtime, &nickel) + .ok(); + nickel + } + } else { + // No cache, do full transpilation + let ast = self.parser.parse(&content)?; + self.transpiler.transpile(&ast)? + }; + + // Evaluate Nickel code to get AgentDefinition + self.evaluator.evaluate(&nickel_code) + } + + /// Execute agent + /// + /// Delegates to AgentExecutor for actual execution with LLM. + /// Returns ExecutionResult with output, validation status, and metadata. + pub async fn execute( + &self, + agent: &AgentDefinition, + inputs: std::collections::HashMap, + ) -> Result { + self.executor.execute(agent, inputs).await + } + + /// Execute agent with streaming output + /// + /// The callback is invoked for each chunk of output as it arrives from the LLM. + /// Useful for real-time display in CLI or web interfaces. + pub async fn execute_streaming( + &self, + agent: &AgentDefinition, + inputs: std::collections::HashMap, + on_chunk: F, + ) -> Result + where + F: FnMut(&str), + { + self.executor + .execute_streaming(agent, inputs, on_chunk) + .await + } + + /// Load and execute in one call + /// + /// Convenience method that combines load() and execute(). + pub async fn load_and_execute( + &self, + path: &std::path::Path, + inputs: std::collections::HashMap, + ) -> Result { + let agent = self.load(path).await?; + self.execute(&agent, inputs).await + } + + /// Load and execute with streaming + /// + /// Convenience method that combines load() and execute_streaming(). + pub async fn load_and_execute_streaming( + &self, + path: &std::path::Path, + inputs: std::collections::HashMap, + on_chunk: F, + ) -> Result + where + F: FnMut(&str), + { + let agent = self.load(path).await?; + self.execute_streaming(&agent, inputs, on_chunk).await + } +} + +impl Default for AgentLoader { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/llm/claude.rs b/crates/typedialog-agent/typedialog-ag-core/src/llm/claude.rs new file mode 100644 index 0000000..62b0444 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/llm/claude.rs @@ -0,0 +1,517 @@ +//! Anthropic Claude API client implementation + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::env; + +use super::provider::{ + LlmProvider, LlmRequest, LlmResponse, LlmStream, MessageRole, StreamChunk, StreamMetadata, + TokenUsage, +}; +use crate::error::{Error, Result}; +use futures::stream::StreamExt; + +const CLAUDE_API_URL: &str = "https://api.anthropic.com/v1/messages"; +const CLAUDE_API_VERSION: &str = "2023-06-01"; + +/// Claude API client +pub struct ClaudeProvider { + api_key: String, + client: reqwest::Client, +} + +/// Claude API request format +#[derive(Debug, Serialize)] +struct ClaudeRequest { + model: String, + messages: Vec, + max_tokens: usize, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + system: Option, + #[serde(skip_serializing_if = "Option::is_none")] + stream: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ClaudeMessage { + role: String, + content: String, +} + +/// Claude API response format +#[derive(Debug, Deserialize)] +struct ClaudeResponse { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + #[serde(rename = "type")] + response_type: String, + #[allow(dead_code)] + role: String, + content: Vec, + model: String, + #[allow(dead_code)] + stop_reason: Option, + usage: ClaudeUsage, +} + +#[derive(Debug, Deserialize)] +struct ClaudeContent { + #[serde(rename = "type")] + content_type: String, + text: String, +} + +#[derive(Debug, Deserialize)] +struct ClaudeUsage { + input_tokens: usize, + output_tokens: usize, +} + +impl ClaudeProvider { + /// Create a new Claude provider + pub fn new() -> Result { + let api_key = env::var("ANTHROPIC_API_KEY").map_err(|_| { + Error::execution( + "ANTHROPIC_API_KEY environment variable not set", + "Set ANTHROPIC_API_KEY to use Claude models", + ) + })?; + + let client = reqwest::Client::new(); + + Ok(Self { api_key, client }) + } + + /// Create a provider with explicit API key + pub fn with_api_key(api_key: String) -> Self { + let client = reqwest::Client::new(); + Self { api_key, client } + } +} + +#[async_trait] +impl LlmProvider for ClaudeProvider { + async fn complete(&self, request: LlmRequest) -> Result { + // Convert generic messages to Claude format + let messages: Vec = request + .messages + .into_iter() + .filter(|m| m.role != MessageRole::System) // System handled separately + .map(|m| ClaudeMessage { + role: match m.role { + MessageRole::User => "user".to_string(), + MessageRole::Assistant => "assistant".to_string(), + MessageRole::System => "user".to_string(), // Fallback + }, + content: m.content, + }) + .collect(); + + // Build Claude request + let claude_request = ClaudeRequest { + model: request.model.clone(), + messages, + max_tokens: request.max_tokens.unwrap_or(4096), + temperature: request.temperature, + system: request.system, + stream: None, + }; + + // Make API call + let response = self + .client + .post(CLAUDE_API_URL) + .header("x-api-key", &self.api_key) + .header("anthropic-version", CLAUDE_API_VERSION) + .header("content-type", "application/json") + .json(&claude_request) + .send() + .await + .map_err(|e| Error::execution("Failed to call Claude API", e.to_string()))?; + + // Check for HTTP errors + if !response.status().is_success() { + let status = response.status(); + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(Error::execution( + format!("Claude API error: {}", status), + error_text, + )); + } + + // Parse response + let claude_response: ClaudeResponse = response + .json() + .await + .map_err(|e| Error::execution("Failed to parse Claude API response", e.to_string()))?; + + // Extract text content + let content = claude_response + .content + .into_iter() + .filter_map(|c| { + if c.content_type == "text" { + Some(c.text) + } else { + None + } + }) + .collect::>() + .join("\n"); + + Ok(LlmResponse { + content, + model: claude_response.model, + usage: Some(TokenUsage { + input_tokens: claude_response.usage.input_tokens, + output_tokens: claude_response.usage.output_tokens, + total_tokens: claude_response.usage.input_tokens + + claude_response.usage.output_tokens, + }), + }) + } + + async fn stream(&self, request: LlmRequest) -> Result { + use futures::stream; + + // Convert generic messages to Claude format + let messages: Vec = request + .messages + .into_iter() + .filter(|m| m.role != MessageRole::System) + .map(|m| ClaudeMessage { + role: match m.role { + MessageRole::User => "user".to_string(), + MessageRole::Assistant => "assistant".to_string(), + MessageRole::System => "user".to_string(), + }, + content: m.content, + }) + .collect(); + + // Build Claude streaming request + let claude_request = ClaudeRequest { + model: request.model.clone(), + messages, + max_tokens: request.max_tokens.unwrap_or(4096), + temperature: request.temperature, + system: request.system, + stream: Some(true), + }; + + // Make streaming API call + let response = self + .client + .post(CLAUDE_API_URL) + .header("x-api-key", &self.api_key) + .header("anthropic-version", CLAUDE_API_VERSION) + .header("content-type", "application/json") + .json(&claude_request) + .send() + .await + .map_err(|e| Error::execution("Failed to call Claude API", e.to_string()))?; + + // Check for HTTP errors + if !response.status().is_success() { + let status = response.status(); + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(Error::execution( + format!("Claude API error: {}", status), + error_text, + )); + } + + // Convert response to SSE stream + use futures::TryStreamExt; + let byte_stream = response.bytes_stream(); + let model = request.model.clone(); + + let sse_stream = byte_stream + .map_err(|e| Error::execution("Stream error", e.to_string())) + .map(move |chunk_result| { + match chunk_result { + Ok(bytes) => { + // Parse SSE events + let text = String::from_utf8_lossy(&bytes); + parse_sse_events(&text, &model) + } + Err(e) => vec![Err(e)], + } + }) + .flat_map(stream::iter); + + Ok(Box::pin(sse_stream)) + } + + fn name(&self) -> &str { + "Claude" + } +} + +/// Parse SSE events from Claude streaming response +fn parse_sse_events(text: &str, model: &str) -> Vec> { + let mut chunks = Vec::new(); + + for line in text.lines() { + if let Some(data) = line.strip_prefix("data: ") { + // Parse JSON event + if let Ok(event) = serde_json::from_str::(data) { + match event.get("type").and_then(|t| t.as_str()) { + Some("content_block_delta") => { + // Extract text delta + if let Some(delta) = event.get("delta") { + if let Some(text) = delta.get("text").and_then(|t| t.as_str()) { + chunks.push(Ok(StreamChunk::Content(text.to_string()))); + } + } + } + Some("message_stop") => { + // Stream completed - extract usage if available + let usage = event.get("usage").and_then(|u| { + Some(TokenUsage { + input_tokens: u.get("input_tokens")?.as_u64()? as usize, + output_tokens: u.get("output_tokens")?.as_u64()? as usize, + total_tokens: (u.get("input_tokens")?.as_u64()? + + u.get("output_tokens")?.as_u64()?) + as usize, + }) + }); + + chunks.push(Ok(StreamChunk::Done(StreamMetadata { + model: model.to_string(), + usage, + }))); + } + Some("message_delta") => { + // Extract usage from delta + if let Some(usage_obj) = event.get("usage") { + let usage = Some(TokenUsage { + input_tokens: usage_obj + .get("input_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0) + as usize, + output_tokens: usage_obj + .get("output_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0) + as usize, + total_tokens: (usage_obj + .get("input_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0) + + usage_obj + .get("output_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0)) + as usize, + }); + + chunks.push(Ok(StreamChunk::Done(StreamMetadata { + model: model.to_string(), + usage, + }))); + } + } + Some("error") => { + let error_msg = event + .get("error") + .and_then(|e| e.get("message")) + .and_then(|m| m.as_str()) + .unwrap_or("Unknown error"); + chunks.push(Ok(StreamChunk::Error(error_msg.to_string()))); + } + _ => {} + } + } + } + } + + chunks +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::llm::LlmMessage; + + #[test] + fn test_create_provider_without_api_key() { + // Should fail if ANTHROPIC_API_KEY not set + let original_key = env::var("ANTHROPIC_API_KEY").ok(); + env::remove_var("ANTHROPIC_API_KEY"); + + let result = ClaudeProvider::new(); + assert!(result.is_err()); + + // Restore original key if it existed + if let Some(key) = original_key { + env::set_var("ANTHROPIC_API_KEY", key); + } + } + + #[test] + fn test_create_provider_with_explicit_key() { + let provider = ClaudeProvider::with_api_key("test-key".to_string()); + assert_eq!(provider.api_key, "test-key"); + assert_eq!(provider.name(), "Claude"); + } + + #[test] + fn test_parse_sse_content_block_delta() { + let sse_data = + r#"data: {"type":"content_block_delta","delta":{"type":"text_delta","text":"Hello"}}"#; + let model = "claude-3-5-haiku-20241022"; + + let chunks = parse_sse_events(sse_data, model); + + assert_eq!(chunks.len(), 1); + match &chunks[0] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, "Hello"), + _ => panic!("Expected Content chunk"), + } + } + + #[test] + fn test_parse_sse_message_stop() { + let sse_data = r#"data: {"type":"message_stop"}"#; + let model = "claude-3-5-haiku-20241022"; + + let chunks = parse_sse_events(sse_data, model); + + assert_eq!(chunks.len(), 1); + match &chunks[0] { + Ok(StreamChunk::Done(metadata)) => { + assert_eq!(metadata.model, model); + assert!(metadata.usage.is_none()); + } + _ => panic!("Expected Done chunk"), + } + } + + #[test] + fn test_parse_sse_message_delta_with_usage() { + let sse_data = + r#"data: {"type":"message_delta","usage":{"input_tokens":100,"output_tokens":50}}"#; + let model = "claude-3-5-haiku-20241022"; + + let chunks = parse_sse_events(sse_data, model); + + assert_eq!(chunks.len(), 1); + match &chunks[0] { + Ok(StreamChunk::Done(metadata)) => { + assert_eq!(metadata.model, model); + let usage = metadata.usage.as_ref().unwrap(); + assert_eq!(usage.input_tokens, 100); + assert_eq!(usage.output_tokens, 50); + assert_eq!(usage.total_tokens, 150); + } + _ => panic!("Expected Done chunk with usage"), + } + } + + #[test] + fn test_parse_sse_error() { + let sse_data = r#"data: {"type":"error","error":{"message":"Rate limit exceeded"}}"#; + let model = "claude-3-5-haiku-20241022"; + + let chunks = parse_sse_events(sse_data, model); + + assert_eq!(chunks.len(), 1); + match &chunks[0] { + Ok(StreamChunk::Error(msg)) => assert_eq!(msg, "Rate limit exceeded"), + _ => panic!("Expected Error chunk"), + } + } + + #[test] + fn test_parse_sse_multiple_events() { + let sse_data = r#"data: {"type":"content_block_delta","delta":{"type":"text_delta","text":"Hello"}} +data: {"type":"content_block_delta","delta":{"type":"text_delta","text":" world"}} +data: {"type":"message_delta","usage":{"input_tokens":10,"output_tokens":5}}"#; + let model = "claude-3-5-haiku-20241022"; + + let chunks = parse_sse_events(sse_data, model); + + assert_eq!(chunks.len(), 3); + + match &chunks[0] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, "Hello"), + _ => panic!("Expected first Content chunk"), + } + + match &chunks[1] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, " world"), + _ => panic!("Expected second Content chunk"), + } + + match &chunks[2] { + Ok(StreamChunk::Done(metadata)) => { + assert!(metadata.usage.is_some()); + } + _ => panic!("Expected Done chunk"), + } + } + + #[test] + fn test_parse_sse_ignores_unknown_events() { + let sse_data = r#"data: {"type":"unknown_event"} +data: {"type":"content_block_delta","delta":{"type":"text_delta","text":"test"}}"#; + let model = "claude-3-5-haiku-20241022"; + + let chunks = parse_sse_events(sse_data, model); + + // Should only parse the content delta, ignore unknown + assert_eq!(chunks.len(), 1); + match &chunks[0] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, "test"), + _ => panic!("Expected Content chunk"), + } + } + + #[test] + fn test_parse_sse_empty_input() { + let sse_data = ""; + let model = "claude-3-5-haiku-20241022"; + + let chunks = parse_sse_events(sse_data, model); + + assert_eq!(chunks.len(), 0); + } + + #[tokio::test] + #[ignore] // Only run with real API key + async fn test_claude_api_call() { + let provider = ClaudeProvider::new().expect("ANTHROPIC_API_KEY must be set"); + + let request = LlmRequest { + model: "claude-3-5-haiku-20241022".to_string(), + messages: vec![LlmMessage { + role: MessageRole::User, + content: "Say hello in exactly 3 words.".to_string(), + }], + max_tokens: Some(50), + temperature: Some(0.7), + system: None, + }; + + let response = provider.complete(request).await; + assert!(response.is_ok()); + + let response = response.unwrap(); + assert!(!response.content.is_empty()); + assert!(response.usage.is_some()); + + println!("Response: {}", response.content); + println!("Usage: {:?}", response.usage); + } +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/llm/gemini.rs b/crates/typedialog-agent/typedialog-ag-core/src/llm/gemini.rs new file mode 100644 index 0000000..70cf057 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/llm/gemini.rs @@ -0,0 +1,555 @@ +//! Google Gemini API client implementation + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::env; + +use super::provider::{ + LlmProvider, LlmRequest, LlmResponse, LlmStream, MessageRole, StreamChunk, StreamMetadata, + TokenUsage, +}; +use crate::error::{Error, Result}; +use futures::stream::StreamExt; + +const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models"; + +/// Gemini API client +pub struct GeminiProvider { + api_key: String, + client: reqwest::Client, +} + +/// Gemini API request format +#[derive(Debug, Serialize)] +struct GeminiRequest { + contents: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "generationConfig")] + generation_config: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "systemInstruction")] + system_instruction: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +struct GeminiContent { + role: String, + parts: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +struct GeminiPart { + text: String, +} + +#[derive(Debug, Serialize)] +struct GeminiGenerationConfig { + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "maxOutputTokens")] + max_output_tokens: Option, +} + +/// Gemini API response format +#[derive(Debug, Deserialize)] +struct GeminiResponse { + candidates: Vec, + #[serde(rename = "usageMetadata")] + usage_metadata: Option, +} + +#[derive(Debug, Deserialize)] +struct GeminiCandidate { + content: GeminiContent, + #[allow(dead_code)] + #[serde(rename = "finishReason")] + finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct GeminiUsageMetadata { + #[serde(rename = "promptTokenCount")] + prompt_token_count: usize, + #[serde(rename = "candidatesTokenCount")] + candidates_token_count: usize, + #[serde(rename = "totalTokenCount")] + total_token_count: usize, +} + +impl GeminiProvider { + /// Create a new Gemini provider + pub fn new() -> Result { + let api_key = env::var("GEMINI_API_KEY") + .or_else(|_| env::var("GOOGLE_API_KEY")) + .map_err(|_| { + Error::execution( + "GEMINI_API_KEY or GOOGLE_API_KEY environment variable not set", + "Set GEMINI_API_KEY or GOOGLE_API_KEY to use Gemini models", + ) + })?; + + let client = reqwest::Client::new(); + + Ok(Self { api_key, client }) + } + + /// Create a provider with explicit API key + pub fn with_api_key(api_key: String) -> Self { + let client = reqwest::Client::new(); + Self { api_key, client } + } + + /// Convert generic messages to Gemini format + fn convert_messages(&self, messages: Vec) -> Vec { + messages + .into_iter() + .filter(|m| m.role != MessageRole::System) + .map(|m| GeminiContent { + role: match m.role { + MessageRole::User => "user".to_string(), + MessageRole::Assistant => "model".to_string(), // Gemini uses "model" not "assistant" + MessageRole::System => "user".to_string(), // Fallback + }, + parts: vec![GeminiPart { text: m.content }], + }) + .collect() + } + + /// Build API URL for the given model + fn build_url(&self, model: &str, streaming: bool) -> String { + let endpoint = if streaming { + "streamGenerateContent" + } else { + "generateContent" + }; + format!( + "{}/{}:{}?key={}", + GEMINI_API_BASE, model, endpoint, self.api_key + ) + } +} + +#[async_trait] +impl LlmProvider for GeminiProvider { + async fn complete(&self, request: LlmRequest) -> Result { + // Convert messages + let contents = self.convert_messages(request.messages); + + // Build generation config + let generation_config = if request.max_tokens.is_some() || request.temperature.is_some() { + Some(GeminiGenerationConfig { + temperature: request.temperature, + max_output_tokens: request.max_tokens, + }) + } else { + None + }; + + // Build system instruction + let system_instruction = request.system.map(|content| GeminiContent { + role: "user".to_string(), + parts: vec![GeminiPart { text: content }], + }); + + // Build Gemini request + let gemini_request = GeminiRequest { + contents, + generation_config, + system_instruction, + }; + + // Make API call + let url = self.build_url(&request.model, false); + let response = self + .client + .post(&url) + .header("Content-Type", "application/json") + .json(&gemini_request) + .send() + .await + .map_err(|e| Error::execution("Failed to call Gemini API", e.to_string()))?; + + // Check for HTTP errors + if !response.status().is_success() { + let status = response.status(); + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(Error::execution( + format!("Gemini API error: {}", status), + error_text, + )); + } + + // Parse response + let gemini_response: GeminiResponse = response + .json() + .await + .map_err(|e| Error::execution("Failed to parse Gemini API response", e.to_string()))?; + + // Extract content from first candidate + let content = gemini_response + .candidates + .first() + .and_then(|candidate| { + candidate + .content + .parts + .first() + .map(|part| part.text.clone()) + }) + .unwrap_or_default(); + + // Extract usage metadata + let usage = gemini_response.usage_metadata.map(|u| TokenUsage { + input_tokens: u.prompt_token_count, + output_tokens: u.candidates_token_count, + total_tokens: u.total_token_count, + }); + + Ok(LlmResponse { + content, + model: request.model, + usage, + }) + } + + async fn stream(&self, request: LlmRequest) -> Result { + use futures::stream; + + // Convert messages + let contents = self.convert_messages(request.messages.clone()); + + // Build generation config + let generation_config = if request.max_tokens.is_some() || request.temperature.is_some() { + Some(GeminiGenerationConfig { + temperature: request.temperature, + max_output_tokens: request.max_tokens, + }) + } else { + None + }; + + // Build system instruction + let system_instruction = request.system.map(|content| GeminiContent { + role: "user".to_string(), + parts: vec![GeminiPart { text: content }], + }); + + // Build Gemini streaming request + let gemini_request = GeminiRequest { + contents, + generation_config, + system_instruction, + }; + + // Make streaming API call + let url = self.build_url(&request.model, true); + let response = self + .client + .post(&url) + .header("Content-Type", "application/json") + .json(&gemini_request) + .send() + .await + .map_err(|e| Error::execution("Failed to call Gemini API", e.to_string()))?; + + // Check for HTTP errors + if !response.status().is_success() { + let status = response.status(); + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(Error::execution( + format!("Gemini API error: {}", status), + error_text, + )); + } + + // Convert response to SSE stream + use futures::TryStreamExt; + let byte_stream = response.bytes_stream(); + let model = request.model.clone(); + + let sse_stream = byte_stream + .map_err(|e| Error::execution("Stream error", e.to_string())) + .map(move |chunk_result| { + match chunk_result { + Ok(bytes) => { + // Parse SSE events + let text = String::from_utf8_lossy(&bytes); + parse_sse_events(&text, &model) + } + Err(e) => vec![Err(e)], + } + }) + .flat_map(stream::iter); + + Ok(Box::pin(sse_stream)) + } + + fn name(&self) -> &str { + "Gemini" + } +} + +/// Parse SSE events from Gemini streaming response +fn parse_sse_events(text: &str, model: &str) -> Vec> { + let mut chunks = Vec::new(); + + for line in text.lines() { + // Gemini streams JSON objects separated by newlines (not SSE format with "data: " prefix) + // Each line is a complete JSON response + if line.trim().is_empty() || line.trim().starts_with('[') { + continue; + } + + // Parse JSON event + if let Ok(event) = serde_json::from_str::(line) { + // Extract content from candidates + if let Some(candidates) = event.get("candidates").and_then(|c| c.as_array()) { + if let Some(candidate) = candidates.first() { + if let Some(content) = candidate.get("content") { + if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) { + if let Some(part) = parts.first() { + if let Some(text) = part.get("text").and_then(|t| t.as_str()) { + chunks.push(Ok(StreamChunk::Content(text.to_string()))); + } + } + } + } + + // Check for finish reason (indicates stream end) + if let Some(_finish_reason) = candidate.get("finishReason") { + // Extract usage metadata if available + let usage = event.get("usageMetadata").and_then(|u| { + Some(TokenUsage { + input_tokens: u.get("promptTokenCount")?.as_u64()? as usize, + output_tokens: u.get("candidatesTokenCount")?.as_u64()? as usize, + total_tokens: u.get("totalTokenCount")?.as_u64()? as usize, + }) + }); + + chunks.push(Ok(StreamChunk::Done(StreamMetadata { + model: model.to_string(), + usage, + }))); + } + } + } + + // Check for errors + if let Some(error) = event.get("error") { + let error_msg = error + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("Unknown error"); + chunks.push(Ok(StreamChunk::Error(error_msg.to_string()))); + } + } + } + + chunks +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::llm::LlmMessage; + + #[test] + fn test_create_provider_without_api_key() { + // Should fail if GEMINI_API_KEY not set + let original_gemini = env::var("GEMINI_API_KEY").ok(); + let original_google = env::var("GOOGLE_API_KEY").ok(); + env::remove_var("GEMINI_API_KEY"); + env::remove_var("GOOGLE_API_KEY"); + + let result = GeminiProvider::new(); + assert!(result.is_err()); + + // Restore original keys if they existed + if let Some(key) = original_gemini { + env::set_var("GEMINI_API_KEY", key); + } + if let Some(key) = original_google { + env::set_var("GOOGLE_API_KEY", key); + } + } + + #[test] + fn test_create_provider_with_explicit_key() { + let provider = GeminiProvider::with_api_key("test-key".to_string()); + assert_eq!(provider.api_key, "test-key"); + assert_eq!(provider.name(), "Gemini"); + } + + #[test] + fn test_convert_messages() { + let provider = GeminiProvider::with_api_key("test-key".to_string()); + + let messages = vec![ + LlmMessage { + role: MessageRole::User, + content: "Hello".to_string(), + }, + LlmMessage { + role: MessageRole::Assistant, + content: "Hi there".to_string(), + }, + ]; + + let gemini_messages = provider.convert_messages(messages); + + assert_eq!(gemini_messages.len(), 2); + assert_eq!(gemini_messages[0].role, "user"); + assert_eq!(gemini_messages[0].parts[0].text, "Hello"); + assert_eq!(gemini_messages[1].role, "model"); + assert_eq!(gemini_messages[1].parts[0].text, "Hi there"); + } + + #[test] + fn test_parse_sse_content() { + let json_line = + r#"{"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"}}]}"#; + let model = "gemini-2.0-flash-exp"; + + let chunks = parse_sse_events(json_line, model); + + assert_eq!(chunks.len(), 1); + match &chunks[0] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, "Hello"), + _ => panic!("Expected Content chunk"), + } + } + + #[test] + fn test_parse_sse_with_finish_reason() { + let json_line = r#"{"candidates":[{"content":{"parts":[{"text":"Done"}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":20,"totalTokenCount":30}}"#; + let model = "gemini-2.0-flash-exp"; + + let chunks = parse_sse_events(json_line, model); + + assert_eq!(chunks.len(), 2); + + match &chunks[0] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, "Done"), + _ => panic!("Expected Content chunk"), + } + + match &chunks[1] { + Ok(StreamChunk::Done(metadata)) => { + assert_eq!(metadata.model, model); + let usage = metadata.usage.as_ref().unwrap(); + assert_eq!(usage.input_tokens, 10); + assert_eq!(usage.output_tokens, 20); + assert_eq!(usage.total_tokens, 30); + } + _ => panic!("Expected Done chunk"), + } + } + + #[test] + fn test_parse_sse_error() { + let json_line = r#"{"error":{"message":"API key invalid"}}"#; + let model = "gemini-2.0-flash-exp"; + + let chunks = parse_sse_events(json_line, model); + + assert_eq!(chunks.len(), 1); + match &chunks[0] { + Ok(StreamChunk::Error(msg)) => assert_eq!(msg, "API key invalid"), + _ => panic!("Expected Error chunk"), + } + } + + #[test] + fn test_parse_sse_multiple_chunks() { + let json_lines = r#"{"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"}}]} +{"candidates":[{"content":{"parts":[{"text":" world"}],"role":"model"}}]} +{"candidates":[{"content":{"parts":[{"text":"!"}],"role":"model"},"finishReason":"STOP"}]}"#; + let model = "gemini-2.0-flash-exp"; + + let chunks = parse_sse_events(json_lines, model); + + assert_eq!(chunks.len(), 4); + + match &chunks[0] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, "Hello"), + _ => panic!("Expected first Content chunk"), + } + + match &chunks[1] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, " world"), + _ => panic!("Expected second Content chunk"), + } + + match &chunks[2] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, "!"), + _ => panic!("Expected third Content chunk"), + } + + match &chunks[3] { + Ok(StreamChunk::Done(_)) => {} + _ => panic!("Expected Done chunk"), + } + } + + #[test] + fn test_parse_sse_empty_input() { + let json_lines = ""; + let model = "gemini-2.0-flash-exp"; + + let chunks = parse_sse_events(json_lines, model); + + assert_eq!(chunks.len(), 0); + } + + #[test] + fn test_build_url_non_streaming() { + let provider = GeminiProvider::with_api_key("test-key-123".to_string()); + let url = provider.build_url("gemini-2.0-flash-exp", false); + + assert!(url.contains("gemini-2.0-flash-exp:generateContent")); + assert!(url.contains("key=test-key-123")); + } + + #[test] + fn test_build_url_streaming() { + let provider = GeminiProvider::with_api_key("test-key-456".to_string()); + let url = provider.build_url("gemini-2.0-flash-exp", true); + + assert!(url.contains("gemini-2.0-flash-exp:streamGenerateContent")); + assert!(url.contains("key=test-key-456")); + } + + #[tokio::test] + #[ignore] // Only run with real API key + async fn test_gemini_api_call() { + let provider = GeminiProvider::new().expect("GEMINI_API_KEY must be set"); + + let request = LlmRequest { + model: "gemini-2.0-flash-exp".to_string(), + messages: vec![LlmMessage { + role: MessageRole::User, + content: "Say hello in exactly 3 words.".to_string(), + }], + max_tokens: Some(50), + temperature: Some(0.7), + system: None, + }; + + let response = provider.complete(request).await; + assert!(response.is_ok()); + + let response = response.unwrap(); + assert!(!response.content.is_empty()); + assert!(response.usage.is_some()); + + println!("Response: {}", response.content); + println!("Usage: {:?}", response.usage); + } +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/llm/mod.rs b/crates/typedialog-agent/typedialog-ag-core/src/llm/mod.rs new file mode 100644 index 0000000..8d59a71 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/llm/mod.rs @@ -0,0 +1,70 @@ +//! LLM provider abstraction and implementations + +pub mod claude; +pub mod gemini; +pub mod ollama; +pub mod openai; +pub mod provider; + +pub use claude::ClaudeProvider; +pub use gemini::GeminiProvider; +pub use ollama::OllamaProvider; +pub use openai::OpenAIProvider; +pub use provider::{ + LlmMessage, LlmProvider, LlmRequest, LlmResponse, LlmStream, MessageRole, StreamChunk, + StreamMetadata, TokenUsage, +}; + +use crate::error::{Error, Result}; + +/// Create an LLM provider based on model name +pub fn create_provider(model: &str) -> Result> { + // Determine provider from model name + if model.starts_with("claude") || model.starts_with("anthropic") { + Ok(Box::new(ClaudeProvider::new()?)) + } else if model.starts_with("gpt") + || model.starts_with("o1") + || model.starts_with("o3") + || model.starts_with("o4") + { + Ok(Box::new(OpenAIProvider::new()?)) + } else if model.starts_with("gemini") { + Ok(Box::new(GeminiProvider::new()?)) + } else if is_ollama_model(model) { + Ok(Box::new(OllamaProvider::new()?)) + } else { + Err(Error::execution( + "Unknown model provider", + format!("Model: {}. Supported: claude-*, gpt-*, o1-*, o3-*, o4-*, gemini-*, llama*, mistral*, phi*, codellama*, etc.", model), + )) + } +} + +/// Check if a model name is a known Ollama model +fn is_ollama_model(model: &str) -> bool { + // Common Ollama model prefixes + let ollama_prefixes = [ + "llama", + "mistral", + "phi", + "codellama", + "mixtral", + "vicuna", + "orca", + "wizardlm", + "falcon", + "starcoder", + "deepseek", + "qwen", + "yi", + "solar", + "dolphin", + "openhermes", + "neural", + "zephyr", + ]; + + ollama_prefixes + .iter() + .any(|prefix| model.starts_with(prefix)) +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/llm/ollama.rs b/crates/typedialog-agent/typedialog-ag-core/src/llm/ollama.rs new file mode 100644 index 0000000..cb9dc29 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/llm/ollama.rs @@ -0,0 +1,504 @@ +//! Ollama API client implementation +//! +//! Ollama provides local LLM execution with an OpenAI-compatible API. +//! Default endpoint: http://localhost:11434 + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::env; + +use super::provider::{ + LlmProvider, LlmRequest, LlmResponse, LlmStream, MessageRole, StreamChunk, StreamMetadata, + TokenUsage, +}; +use crate::error::{Error, Result}; +use futures::stream::StreamExt; + +const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434"; + +/// Ollama API client +pub struct OllamaProvider { + base_url: String, + client: reqwest::Client, +} + +/// Ollama API request format (OpenAI-compatible) +#[derive(Debug, Serialize)] +struct OllamaRequest { + model: String, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + options: Option, + #[serde(skip_serializing_if = "Option::is_none")] + stream: Option, +} + +#[derive(Debug, Serialize)] +struct OllamaOptions { + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + num_predict: Option, // Ollama's equivalent of max_tokens +} + +#[derive(Debug, Serialize, Deserialize)] +struct OllamaMessage { + role: String, + content: String, +} + +/// Ollama API response format +#[derive(Debug, Deserialize)] +struct OllamaResponse { + model: String, + #[allow(dead_code)] + created_at: String, + message: OllamaMessage, + #[allow(dead_code)] + done: bool, + #[serde(default)] + prompt_eval_count: Option, + #[serde(default)] + eval_count: Option, +} + +impl OllamaProvider { + /// Create a new Ollama provider with default localhost URL + pub fn new() -> Result { + let base_url = + env::var("OLLAMA_BASE_URL").unwrap_or_else(|_| DEFAULT_OLLAMA_URL.to_string()); + + let client = reqwest::Client::new(); + + Ok(Self { base_url, client }) + } + + /// Create a provider with custom base URL + pub fn with_base_url(base_url: String) -> Self { + let client = reqwest::Client::new(); + Self { base_url, client } + } + + /// Build API URL for chat endpoint + fn build_url(&self) -> String { + format!("{}/api/chat", self.base_url) + } +} + +#[async_trait] +impl LlmProvider for OllamaProvider { + async fn complete(&self, request: LlmRequest) -> Result { + // Convert generic messages to Ollama format + let mut messages: Vec = request + .messages + .into_iter() + .filter(|m| m.role != MessageRole::System) + .map(|m| OllamaMessage { + role: match m.role { + MessageRole::User => "user".to_string(), + MessageRole::Assistant => "assistant".to_string(), + MessageRole::System => "system".to_string(), + }, + content: m.content, + }) + .collect(); + + // Add system message at the beginning if provided + if let Some(system_prompt) = request.system { + messages.insert( + 0, + OllamaMessage { + role: "system".to_string(), + content: system_prompt, + }, + ); + } + + // Build options + let options = if request.max_tokens.is_some() || request.temperature.is_some() { + Some(OllamaOptions { + temperature: request.temperature, + num_predict: request.max_tokens.map(|t| t as i32), + }) + } else { + None + }; + + // Build Ollama request + let ollama_request = OllamaRequest { + model: request.model.clone(), + messages, + options, + stream: Some(false), + }; + + // Make API call + let url = self.build_url(); + let response = self + .client + .post(&url) + .header("Content-Type", "application/json") + .json(&ollama_request) + .send() + .await + .map_err(|e| { + Error::execution( + "Failed to call Ollama API - is Ollama running?", + format!("{} (URL: {})", e, url), + ) + })?; + + // Check for HTTP errors + if !response.status().is_success() { + let status = response.status(); + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(Error::execution( + format!("Ollama API error: {}", status), + error_text, + )); + } + + // Parse response + let ollama_response: OllamaResponse = response + .json() + .await + .map_err(|e| Error::execution("Failed to parse Ollama API response", e.to_string()))?; + + // Calculate usage + let usage = if ollama_response.prompt_eval_count.is_some() + || ollama_response.eval_count.is_some() + { + let input_tokens = ollama_response.prompt_eval_count.unwrap_or(0) as usize; + let output_tokens = ollama_response.eval_count.unwrap_or(0) as usize; + Some(TokenUsage { + input_tokens, + output_tokens, + total_tokens: input_tokens + output_tokens, + }) + } else { + None + }; + + Ok(LlmResponse { + content: ollama_response.message.content, + model: ollama_response.model, + usage, + }) + } + + async fn stream(&self, request: LlmRequest) -> Result { + use futures::stream; + + // Convert generic messages to Ollama format + let mut messages: Vec = request + .messages + .into_iter() + .filter(|m| m.role != MessageRole::System) + .map(|m| OllamaMessage { + role: match m.role { + MessageRole::User => "user".to_string(), + MessageRole::Assistant => "assistant".to_string(), + MessageRole::System => "system".to_string(), + }, + content: m.content, + }) + .collect(); + + // Add system message at the beginning if provided + if let Some(system_prompt) = request.system { + messages.insert( + 0, + OllamaMessage { + role: "system".to_string(), + content: system_prompt, + }, + ); + } + + // Build options + let options = if request.max_tokens.is_some() || request.temperature.is_some() { + Some(OllamaOptions { + temperature: request.temperature, + num_predict: request.max_tokens.map(|t| t as i32), + }) + } else { + None + }; + + // Build Ollama streaming request + let ollama_request = OllamaRequest { + model: request.model.clone(), + messages, + options, + stream: Some(true), + }; + + // Make streaming API call + let url = self.build_url(); + let response = self + .client + .post(&url) + .header("Content-Type", "application/json") + .json(&ollama_request) + .send() + .await + .map_err(|e| { + Error::execution( + "Failed to call Ollama API - is Ollama running?", + format!("{} (URL: {})", e, url), + ) + })?; + + // Check for HTTP errors + if !response.status().is_success() { + let status = response.status(); + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(Error::execution( + format!("Ollama API error: {}", status), + error_text, + )); + } + + // Convert response to JSON stream + use futures::TryStreamExt; + let byte_stream = response.bytes_stream(); + let model = request.model.clone(); + + let json_stream = byte_stream + .map_err(|e| Error::execution("Stream error", e.to_string())) + .map(move |chunk_result| { + match chunk_result { + Ok(bytes) => { + // Parse JSON events (newline-delimited like Gemini) + let text = String::from_utf8_lossy(&bytes); + parse_json_events(&text, &model) + } + Err(e) => vec![Err(e)], + } + }) + .flat_map(stream::iter); + + Ok(Box::pin(json_stream)) + } + + fn name(&self) -> &str { + "Ollama" + } +} + +/// Parse JSON events from Ollama streaming response +fn parse_json_events(text: &str, model: &str) -> Vec> { + let mut chunks = Vec::new(); + + for line in text.lines() { + if line.trim().is_empty() { + continue; + } + + // Parse JSON event + if let Ok(event) = serde_json::from_str::(line) { + // Extract message content + if let Some(message) = event.get("message") { + if let Some(content) = message.get("content").and_then(|c| c.as_str()) { + if !content.is_empty() { + chunks.push(Ok(StreamChunk::Content(content.to_string()))); + } + } + } + + // Check for completion (done: true) + if let Some(true) = event.get("done").and_then(|d| d.as_bool()) { + // Extract usage metadata + let usage = Some(TokenUsage { + input_tokens: event + .get("prompt_eval_count") + .and_then(|v| v.as_i64()) + .unwrap_or(0) as usize, + output_tokens: event + .get("eval_count") + .and_then(|v| v.as_i64()) + .unwrap_or(0) as usize, + total_tokens: (event + .get("prompt_eval_count") + .and_then(|v| v.as_i64()) + .unwrap_or(0) + + event + .get("eval_count") + .and_then(|v| v.as_i64()) + .unwrap_or(0)) as usize, + }); + + chunks.push(Ok(StreamChunk::Done(StreamMetadata { + model: model.to_string(), + usage, + }))); + } + + // Check for errors + if let Some(error) = event.get("error") { + let error_msg = error.as_str().unwrap_or("Unknown error"); + chunks.push(Ok(StreamChunk::Error(error_msg.to_string()))); + } + } + } + + chunks +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::llm::LlmMessage; + + #[test] + fn test_create_provider_with_default_url() { + let provider = OllamaProvider::new().expect("Should create provider"); + assert_eq!(provider.base_url, DEFAULT_OLLAMA_URL); + assert_eq!(provider.name(), "Ollama"); + } + + #[test] + fn test_create_provider_with_custom_url() { + let custom_url = "http://custom-host:8080"; + let provider = OllamaProvider::with_base_url(custom_url.to_string()); + assert_eq!(provider.base_url, custom_url); + assert_eq!(provider.name(), "Ollama"); + } + + #[test] + fn test_build_url() { + let provider = OllamaProvider::with_base_url("http://localhost:11434".to_string()); + assert_eq!(provider.build_url(), "http://localhost:11434/api/chat"); + } + + #[test] + fn test_parse_json_content() { + let json_line = r#"{"model":"llama2","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hello"},"done":false}"#; + let model = "llama2"; + + let chunks = parse_json_events(json_line, model); + + assert_eq!(chunks.len(), 1); + match &chunks[0] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, "Hello"), + _ => panic!("Expected Content chunk"), + } + } + + #[test] + fn test_parse_json_done() { + let json_line = r#"{"model":"llama2","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":10,"eval_count":20}"#; + let model = "llama2"; + + let chunks = parse_json_events(json_line, model); + + assert_eq!(chunks.len(), 1); + match &chunks[0] { + Ok(StreamChunk::Done(metadata)) => { + assert_eq!(metadata.model, model); + let usage = metadata.usage.as_ref().unwrap(); + assert_eq!(usage.input_tokens, 10); + assert_eq!(usage.output_tokens, 20); + assert_eq!(usage.total_tokens, 30); + } + _ => panic!("Expected Done chunk"), + } + } + + #[test] + fn test_parse_json_error() { + let json_line = r#"{"error":"model not found"}"#; + let model = "llama2"; + + let chunks = parse_json_events(json_line, model); + + assert_eq!(chunks.len(), 1); + match &chunks[0] { + Ok(StreamChunk::Error(msg)) => assert_eq!(msg, "model not found"), + _ => panic!("Expected Error chunk"), + } + } + + #[test] + fn test_parse_json_multiple_chunks() { + let json_lines = r#"{"model":"llama2","message":{"role":"assistant","content":"Hello"},"done":false} +{"model":"llama2","message":{"role":"assistant","content":" world"},"done":false} +{"model":"llama2","message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":5,"eval_count":10}"#; + let model = "llama2"; + + let chunks = parse_json_events(json_lines, model); + + assert_eq!(chunks.len(), 3); + + match &chunks[0] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, "Hello"), + _ => panic!("Expected first Content chunk"), + } + + match &chunks[1] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, " world"), + _ => panic!("Expected second Content chunk"), + } + + match &chunks[2] { + Ok(StreamChunk::Done(metadata)) => { + assert!(metadata.usage.is_some()); + } + _ => panic!("Expected Done chunk"), + } + } + + #[test] + fn test_parse_json_empty_content_ignored() { + let json_line = + r#"{"model":"llama2","message":{"role":"assistant","content":""},"done":false}"#; + let model = "llama2"; + + let chunks = parse_json_events(json_line, model); + + // Empty content should not produce a chunk + assert_eq!(chunks.len(), 0); + } + + #[test] + fn test_parse_json_empty_input() { + let json_lines = ""; + let model = "llama2"; + + let chunks = parse_json_events(json_lines, model); + + assert_eq!(chunks.len(), 0); + } + + #[tokio::test] + #[ignore] // Only run when Ollama is running locally + async fn test_ollama_api_call() { + let provider = OllamaProvider::new().expect("Should create provider"); + + let request = LlmRequest { + model: "llama2".to_string(), + messages: vec![LlmMessage { + role: MessageRole::User, + content: "Say hello in exactly 3 words.".to_string(), + }], + max_tokens: Some(50), + temperature: Some(0.7), + system: None, + }; + + let response = provider.complete(request).await; + assert!(response.is_ok()); + + let response = response.unwrap(); + assert!(!response.content.is_empty()); + + println!("Response: {}", response.content); + println!("Usage: {:?}", response.usage); + } +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/llm/openai.rs b/crates/typedialog-agent/typedialog-ag-core/src/llm/openai.rs new file mode 100644 index 0000000..db0d67d --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/llm/openai.rs @@ -0,0 +1,457 @@ +//! OpenAI API client implementation + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::env; + +use super::provider::{ + LlmProvider, LlmRequest, LlmResponse, LlmStream, MessageRole, StreamChunk, StreamMetadata, + TokenUsage, +}; +use crate::error::{Error, Result}; +use futures::stream::StreamExt; + +const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions"; + +/// OpenAI API client +pub struct OpenAIProvider { + api_key: String, + client: reqwest::Client, +} + +/// OpenAI API request format +#[derive(Debug, Serialize)] +struct OpenAIRequest { + model: String, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + stream: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct OpenAIMessage { + role: String, + content: String, +} + +/// OpenAI API response format +#[derive(Debug, Deserialize)] +struct OpenAIResponse { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + object: String, + #[allow(dead_code)] + created: u64, + model: String, + choices: Vec, + usage: OpenAIUsage, +} + +#[derive(Debug, Deserialize)] +struct OpenAIChoice { + #[allow(dead_code)] + index: usize, + message: OpenAIMessage, + #[allow(dead_code)] + finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct OpenAIUsage { + prompt_tokens: usize, + completion_tokens: usize, + total_tokens: usize, +} + +impl OpenAIProvider { + /// Create a new OpenAI provider + pub fn new() -> Result { + let api_key = env::var("OPENAI_API_KEY").map_err(|_| { + Error::execution( + "OPENAI_API_KEY environment variable not set", + "Set OPENAI_API_KEY to use OpenAI models", + ) + })?; + + let client = reqwest::Client::new(); + + Ok(Self { api_key, client }) + } + + /// Create a provider with explicit API key + pub fn with_api_key(api_key: String) -> Self { + let client = reqwest::Client::new(); + Self { api_key, client } + } +} + +#[async_trait] +impl LlmProvider for OpenAIProvider { + async fn complete(&self, request: LlmRequest) -> Result { + // Convert generic messages to OpenAI format + let mut messages: Vec = request + .messages + .into_iter() + .filter(|m| m.role != MessageRole::System) + .map(|m| OpenAIMessage { + role: match m.role { + MessageRole::User => "user".to_string(), + MessageRole::Assistant => "assistant".to_string(), + MessageRole::System => "system".to_string(), + }, + content: m.content, + }) + .collect(); + + // Add system message at the beginning if provided + if let Some(system_prompt) = request.system { + messages.insert( + 0, + OpenAIMessage { + role: "system".to_string(), + content: system_prompt, + }, + ); + } + + // Build OpenAI request + let openai_request = OpenAIRequest { + model: request.model.clone(), + messages, + max_tokens: request.max_tokens, + temperature: request.temperature, + stream: None, + }; + + // Make API call + let response = self + .client + .post(OPENAI_API_URL) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&openai_request) + .send() + .await + .map_err(|e| Error::execution("Failed to call OpenAI API", e.to_string()))?; + + // Check for HTTP errors + if !response.status().is_success() { + let status = response.status(); + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(Error::execution( + format!("OpenAI API error: {}", status), + error_text, + )); + } + + // Parse response + let openai_response: OpenAIResponse = response + .json() + .await + .map_err(|e| Error::execution("Failed to parse OpenAI API response", e.to_string()))?; + + // Extract content from first choice + let content = openai_response + .choices + .first() + .map(|choice| choice.message.content.clone()) + .unwrap_or_default(); + + Ok(LlmResponse { + content, + model: openai_response.model, + usage: Some(TokenUsage { + input_tokens: openai_response.usage.prompt_tokens, + output_tokens: openai_response.usage.completion_tokens, + total_tokens: openai_response.usage.total_tokens, + }), + }) + } + + async fn stream(&self, request: LlmRequest) -> Result { + use futures::stream; + + // Convert generic messages to OpenAI format + let mut messages: Vec = request + .messages + .into_iter() + .filter(|m| m.role != MessageRole::System) + .map(|m| OpenAIMessage { + role: match m.role { + MessageRole::User => "user".to_string(), + MessageRole::Assistant => "assistant".to_string(), + MessageRole::System => "system".to_string(), + }, + content: m.content, + }) + .collect(); + + // Add system message at the beginning if provided + if let Some(system_prompt) = request.system { + messages.insert( + 0, + OpenAIMessage { + role: "system".to_string(), + content: system_prompt, + }, + ); + } + + // Build OpenAI streaming request + let openai_request = OpenAIRequest { + model: request.model.clone(), + messages, + max_tokens: request.max_tokens, + temperature: request.temperature, + stream: Some(true), + }; + + // Make streaming API call + let response = self + .client + .post(OPENAI_API_URL) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&openai_request) + .send() + .await + .map_err(|e| Error::execution("Failed to call OpenAI API", e.to_string()))?; + + // Check for HTTP errors + if !response.status().is_success() { + let status = response.status(); + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(Error::execution( + format!("OpenAI API error: {}", status), + error_text, + )); + } + + // Convert response to SSE stream + use futures::TryStreamExt; + let byte_stream = response.bytes_stream(); + let model = request.model.clone(); + + let sse_stream = byte_stream + .map_err(|e| Error::execution("Stream error", e.to_string())) + .map(move |chunk_result| { + match chunk_result { + Ok(bytes) => { + // Parse SSE events + let text = String::from_utf8_lossy(&bytes); + parse_sse_events(&text, &model) + } + Err(e) => vec![Err(e)], + } + }) + .flat_map(stream::iter); + + Ok(Box::pin(sse_stream)) + } + + fn name(&self) -> &str { + "OpenAI" + } +} + +/// Parse SSE events from OpenAI streaming response +fn parse_sse_events(text: &str, model: &str) -> Vec> { + let mut chunks = Vec::new(); + + for line in text.lines() { + if let Some(data) = line.strip_prefix("data: ") { + // Check for completion signal + if data.trim() == "[DONE]" { + chunks.push(Ok(StreamChunk::Done(StreamMetadata { + model: model.to_string(), + usage: None, // OpenAI doesn't send usage in streaming mode + }))); + continue; + } + + // Parse JSON event + if let Ok(event) = serde_json::from_str::(data) { + // Extract delta content + if let Some(choices) = event.get("choices").and_then(|c| c.as_array()) { + if let Some(choice) = choices.first() { + if let Some(delta) = choice.get("delta") { + if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { + chunks.push(Ok(StreamChunk::Content(content.to_string()))); + } + } + } + } + + // Check for errors + if let Some(error) = event.get("error") { + let error_msg = error + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("Unknown error"); + chunks.push(Ok(StreamChunk::Error(error_msg.to_string()))); + } + } + } + } + + chunks +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::llm::LlmMessage; + + #[test] + fn test_create_provider_without_api_key() { + // Should fail if OPENAI_API_KEY not set + let original_key = env::var("OPENAI_API_KEY").ok(); + env::remove_var("OPENAI_API_KEY"); + + let result = OpenAIProvider::new(); + assert!(result.is_err()); + + // Restore original key if it existed + if let Some(key) = original_key { + env::set_var("OPENAI_API_KEY", key); + } + } + + #[test] + fn test_create_provider_with_explicit_key() { + let provider = OpenAIProvider::with_api_key("test-key".to_string()); + assert_eq!(provider.api_key, "test-key"); + assert_eq!(provider.name(), "OpenAI"); + } + + #[test] + fn test_parse_sse_content_delta() { + let sse_data = r#"data: {"choices":[{"delta":{"content":"Hello"}}]}"#; + let model = "gpt-4"; + + let chunks = parse_sse_events(sse_data, model); + + assert_eq!(chunks.len(), 1); + match &chunks[0] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, "Hello"), + _ => panic!("Expected Content chunk"), + } + } + + #[test] + fn test_parse_sse_done_signal() { + let sse_data = "data: [DONE]"; + let model = "gpt-4"; + + let chunks = parse_sse_events(sse_data, model); + + assert_eq!(chunks.len(), 1); + match &chunks[0] { + Ok(StreamChunk::Done(metadata)) => { + assert_eq!(metadata.model, model); + assert!(metadata.usage.is_none()); + } + _ => panic!("Expected Done chunk"), + } + } + + #[test] + fn test_parse_sse_error() { + let sse_data = r#"data: {"error":{"message":"Rate limit exceeded"}}"#; + let model = "gpt-4"; + + let chunks = parse_sse_events(sse_data, model); + + assert_eq!(chunks.len(), 1); + match &chunks[0] { + Ok(StreamChunk::Error(msg)) => assert_eq!(msg, "Rate limit exceeded"), + _ => panic!("Expected Error chunk"), + } + } + + #[test] + fn test_parse_sse_multiple_events() { + let sse_data = r#"data: {"choices":[{"delta":{"content":"Hello"}}]} +data: {"choices":[{"delta":{"content":" world"}}]} +data: [DONE]"#; + let model = "gpt-4"; + + let chunks = parse_sse_events(sse_data, model); + + assert_eq!(chunks.len(), 3); + + match &chunks[0] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, "Hello"), + _ => panic!("Expected first Content chunk"), + } + + match &chunks[1] { + Ok(StreamChunk::Content(text)) => assert_eq!(text, " world"), + _ => panic!("Expected second Content chunk"), + } + + match &chunks[2] { + Ok(StreamChunk::Done(_)) => {} + _ => panic!("Expected Done chunk"), + } + } + + #[test] + fn test_parse_sse_empty_delta() { + let sse_data = r#"data: {"choices":[{"delta":{}}]}"#; + let model = "gpt-4"; + + let chunks = parse_sse_events(sse_data, model); + + // Should not produce any chunks for empty delta + assert_eq!(chunks.len(), 0); + } + + #[test] + fn test_parse_sse_empty_input() { + let sse_data = ""; + let model = "gpt-4"; + + let chunks = parse_sse_events(sse_data, model); + + assert_eq!(chunks.len(), 0); + } + + #[tokio::test] + #[ignore] // Only run with real API key + async fn test_openai_api_call() { + let provider = OpenAIProvider::new().expect("OPENAI_API_KEY must be set"); + + let request = LlmRequest { + model: "gpt-4o-mini".to_string(), + messages: vec![LlmMessage { + role: MessageRole::User, + content: "Say hello in exactly 3 words.".to_string(), + }], + max_tokens: Some(50), + temperature: Some(0.7), + system: None, + }; + + let response = provider.complete(request).await; + assert!(response.is_ok()); + + let response = response.unwrap(); + assert!(!response.content.is_empty()); + assert!(response.usage.is_some()); + + println!("Response: {}", response.content); + println!("Usage: {:?}", response.usage); + } +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/llm/provider.rs b/crates/typedialog-agent/typedialog-ag-core/src/llm/provider.rs new file mode 100644 index 0000000..a2579fc --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/llm/provider.rs @@ -0,0 +1,83 @@ +//! LLM provider trait and common types + +use crate::error::Result; +use async_trait::async_trait; +use futures::stream::Stream; +use serde::{Deserialize, Serialize}; +use std::pin::Pin; + +/// Role of a message in conversation +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum MessageRole { + User, + Assistant, + System, +} + +/// A single message in the conversation +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LlmMessage { + pub role: MessageRole, + pub content: String, +} + +/// Request to LLM provider +#[derive(Debug, Clone)] +pub struct LlmRequest { + pub model: String, + pub messages: Vec, + pub max_tokens: Option, + pub temperature: Option, + pub system: Option, +} + +/// Response from LLM provider +#[derive(Debug, Clone)] +pub struct LlmResponse { + pub content: String, + pub model: String, + pub usage: Option, +} + +/// Streaming chunk from LLM +#[derive(Debug, Clone)] +pub enum StreamChunk { + /// Text content delta + Content(String), + /// Stream completed with final metadata + Done(StreamMetadata), + /// Error occurred + Error(String), +} + +/// Metadata when stream completes +#[derive(Debug, Clone)] +pub struct StreamMetadata { + pub model: String, + pub usage: Option, +} + +/// Token usage statistics +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenUsage { + pub input_tokens: usize, + pub output_tokens: usize, + pub total_tokens: usize, +} + +/// Stream type alias +pub type LlmStream = Pin> + Send>>; + +/// LLM provider trait - implemented by Claude, OpenAI, Gemini, etc. +#[async_trait] +pub trait LlmProvider: Send + Sync { + /// Execute a completion request + async fn complete(&self, request: LlmRequest) -> Result; + + /// Stream a completion request + async fn stream(&self, request: LlmRequest) -> Result; + + /// Get the provider name + fn name(&self) -> &str; +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/nickel/mod.rs b/crates/typedialog-agent/typedialog-ag-core/src/nickel/mod.rs new file mode 100644 index 0000000..e74345c --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/nickel/mod.rs @@ -0,0 +1,190 @@ +//! Nickel evaluation and validation +//! +//! Evaluates Nickel code to produce AgentDefinition with type checking + +use crate::error::{Error, Result}; +use nickel_lang_core::program::Program; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Agent configuration from Nickel evaluation +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AgentConfig { + pub role: String, + pub llm: String, + #[serde(default)] + pub tools: Vec, + #[serde(default = "default_max_tokens")] + pub max_tokens: usize, + #[serde(default = "default_temperature")] + pub temperature: f64, +} + +fn default_max_tokens() -> usize { + 4096 +} + +fn default_temperature() -> f64 { + 0.7 +} + +/// Complete agent definition after Nickel evaluation +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AgentDefinition { + pub config: AgentConfig, + #[serde(default)] + pub inputs: HashMap, + #[serde(default)] + pub context: Option, + #[serde(default)] + pub validation: Option, + #[serde(default)] + pub template: String, +} + +/// Context sources (imports, shell commands) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ContextSources { + #[serde(default)] + pub imports: Vec, + #[serde(default)] + pub shell_commands: Vec, +} + +/// Output validation rules +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ValidationRules { + #[serde(default)] + pub must_contain: Vec, + #[serde(default)] + pub must_not_contain: Vec, + #[serde(default = "default_format")] + pub format: String, + pub min_length: Option, + pub max_length: Option, +} + +fn default_format() -> String { + "markdown".to_string() +} + +/// Nickel evaluator +pub struct NickelEvaluator; + +impl NickelEvaluator { + /// Create a new Nickel evaluator + pub fn new() -> Self { + Self + } + + /// Evaluate Nickel code to AgentDefinition + pub fn evaluate(&self, nickel_code: &str) -> Result { + use nickel_lang_core::error::NullReporter; + use nickel_lang_core::eval::value::lazy::CBNCache; + use nickel_lang_core::serialize::ExportFormat; + use std::io::Cursor; + + // Parse and create program with null reporter (we handle errors ourselves) + let mut program: Program = Program::new_from_source( + Cursor::new(nickel_code), + "", + std::io::sink(), + NullReporter {}, + ) + .map_err(|e| Error::nickel_eval("Failed to parse Nickel code", format!("{:?}", e)))?; + + // Evaluate to get the term + let term = program.eval_full_for_export().map_err(|e| { + Error::nickel_eval("Failed to evaluate Nickel code", format!("{:?}", e)) + })?; + + // Serialize to JSON + let mut output = Vec::new(); + nickel_lang_core::serialize::to_writer(&mut output, ExportFormat::Json, &term) + .map_err(|e| Error::nickel_eval("Failed to serialize to JSON", format!("{:?}", e)))?; + + // Parse JSON and deserialize + let json_value: serde_json::Value = serde_json::from_slice(&output) + .map_err(|e| Error::nickel_eval("Failed to parse JSON output", e.to_string()))?; + + serde_json::from_value(json_value) + .map_err(|e| Error::nickel_eval("Failed to deserialize AgentDefinition", e.to_string())) + } + + /// Type check Nickel code without evaluation + pub fn typecheck(&self, nickel_code: &str) -> Result<()> { + use nickel_lang_core::error::NullReporter; + use nickel_lang_core::eval::value::lazy::CBNCache; + use nickel_lang_core::typecheck::TypecheckMode; + use std::io::Cursor; + + // Parse Nickel code + let mut program: Program = Program::new_from_source( + Cursor::new(nickel_code), + "", + std::io::sink(), + NullReporter {}, + ) + .map_err(|e| { + Error::nickel_eval( + "Failed to parse Nickel code for type checking", + format!("{:?}", e), + ) + })?; + + // Type check with strict mode + program + .typecheck(TypecheckMode::Enforce) + .map_err(|e| Error::nickel_eval("Type checking failed", format!("{:?}", e)))?; + + Ok(()) + } +} + +impl Default for NickelEvaluator { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_evaluate_simple_nickel() { + let nickel_code = r#"{ + config = { + role = "architect", + llm = "claude-opus-4", + tools = ["analyze"], + }, + inputs = {}, + template = "test", +}"#; + + let evaluator = NickelEvaluator::new(); + let result = evaluator.evaluate(nickel_code); + + assert!(result.is_ok()); + let agent_def = result.unwrap(); + assert_eq!(agent_def.config.role, "architect"); + assert_eq!(agent_def.config.llm, "claude-opus-4"); + assert_eq!(agent_def.config.tools, vec!["analyze"]); + } + + #[test] + fn test_typecheck_valid_nickel() { + let nickel_code = r#"{ + config = { + role = "architect", + llm = "claude-opus-4", + }, +}"#; + + let evaluator = NickelEvaluator::new(); + let result = evaluator.typecheck(nickel_code); + + assert!(result.is_ok()); + } +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/parser/ast.rs b/crates/typedialog-agent/typedialog-ag-core/src/parser/ast.rs new file mode 100644 index 0000000..b37f19b --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/parser/ast.rs @@ -0,0 +1,58 @@ +//! Abstract Syntax Tree for markup + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MarkupAst { + pub nodes: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum MarkupNode { + Agent(AgentConfig), + Input(InputDecl), + Import(ImportDecl), + Shell(ShellDecl), + If(IfBlock), + Validate(ValidationRules), + Variable(String), + Text(String), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AgentConfig { + pub role: String, + pub llm: String, + pub tools: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InputDecl { + pub name: String, + pub type_spec: String, + pub optional: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImportDecl { + pub path: String, + pub alias: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ShellDecl { + pub command: String, + pub alias: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct IfBlock { + pub condition: String, + pub body: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ValidationRules { + pub must_contain: Vec, + pub format: String, +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/parser/directives.rs b/crates/typedialog-agent/typedialog-ag-core/src/parser/directives.rs new file mode 100644 index 0000000..4913e52 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/parser/directives.rs @@ -0,0 +1,28 @@ +//! Directive parsing for @directives in MDX files +//! +//! Supported directives: +//! - @agent { role: architect, llm: claude-opus-4 } +//! - @input name: Type +//! - @import "path" as alias +//! - @shell "command" as alias +//! - @if condition { ... } +//! - @validate output { must_contain: [...] } + +use super::ast::{AgentConfig, IfBlock, ImportDecl, InputDecl, ShellDecl, ValidationRules}; + +/// Agent directive type +#[derive(Debug, Clone)] +pub enum AgentDirective { + /// Agent configuration block + Agent(AgentConfig), + /// Input declaration + Input(InputDecl), + /// File import + Import(ImportDecl), + /// Shell command execution + Shell(ShellDecl), + /// Conditional block + If(IfBlock), + /// Output validation rules + Validate(ValidationRules), +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/parser/markdown.rs b/crates/typedialog-agent/typedialog-ag-core/src/parser/markdown.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/parser/markdown.rs @@ -0,0 +1 @@ + diff --git a/crates/typedialog-agent/typedialog-ag-core/src/parser/mdx.rs b/crates/typedialog-agent/typedialog-ag-core/src/parser/mdx.rs new file mode 100644 index 0000000..4bdcd1a --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/parser/mdx.rs @@ -0,0 +1,265 @@ +//! MDX parser for .agent.mdx files +//! +//! Parses MDX with @ directives: +//! - @agent { ... } +//! - @input name: Type +//! - @import "path" as alias +//! - @shell "command" as alias +//! - @if condition { ... } +//! - @validate output { ... } + +use nom::{ + branch::alt, + bytes::complete::{tag, take_until, take_while1}, + character::complete::{char, line_ending, multispace0, multispace1, not_line_ending}, + combinator::map, + multi::many0, + sequence::{delimited, terminated}, + IResult, Parser, +}; + +use super::ast::*; +use super::directives::AgentDirective; + +/// Parse frontmatter delimited by --- +pub fn parse_frontmatter(input: &str) -> IResult<&str, Vec> { + let (input, _) = tag("---").parse(input)?; + let (input, _) = line_ending.parse(input)?; + let (input, directives) = many0(terminated(parse_directive, multispace0)).parse(input)?; + let (input, _) = tag("---").parse(input)?; + let (input, _) = line_ending.parse(input)?; + Ok((input, directives)) +} + +/// Parse any @ directive +fn parse_directive(input: &str) -> IResult<&str, AgentDirective> { + alt(( + map(parse_agent_directive, AgentDirective::Agent), + map(parse_input_directive, AgentDirective::Input), + map(parse_import_directive, AgentDirective::Import), + map(parse_shell_directive, AgentDirective::Shell), + map(parse_validate_directive, AgentDirective::Validate), + )) + .parse(input) +} + +/// Parse @agent { role: architect, llm: claude-opus-4, tools: [...] } +fn parse_agent_directive(input: &str) -> IResult<&str, AgentConfig> { + let (input, _) = tag("@agent").parse(input)?; + let (input, _) = multispace0.parse(input)?; + let (input, content) = delimited(char('{'), take_until("}"), char('}')).parse(input)?; + + let config = parse_agent_config(content); + Ok((input, config)) +} + +/// Parse agent config fields (simplified - not full Nickel parser) +fn parse_agent_config(input: &str) -> AgentConfig { + let mut role = String::new(); + let mut llm = String::new(); + let mut tools = Vec::new(); + + for line in input.lines() { + let line = line.trim(); + if line.starts_with("role:") { + role = line + .trim_start_matches("role:") + .trim() + .trim_end_matches(',') + .trim() + .to_string(); + } else if line.starts_with("llm:") { + llm = line + .trim_start_matches("llm:") + .trim() + .trim_end_matches(',') + .trim() + .to_string(); + } else if line.starts_with("tools:") { + let tools_str = line + .trim_start_matches("tools:") + .trim() + .trim_end_matches(','); + if let Some(stripped) = tools_str + .strip_prefix('[') + .and_then(|s| s.strip_suffix(']')) + { + tools = stripped.split(',').map(|s| s.trim().to_string()).collect(); + } + } + } + + AgentConfig { role, llm, tools } +} + +/// Parse @input name: Type or @input name?: Type (optional) +fn parse_input_directive(input: &str) -> IResult<&str, InputDecl> { + let (input, _) = tag("@input").parse(input)?; + let (input, _) = multispace1.parse(input)?; + + let (input, name_part) = + take_while1(|c: char| c.is_alphanumeric() || c == '_' || c == '?').parse(input)?; + let (input, _) = char(':').parse(input)?; + let (input, _) = multispace0.parse(input)?; + let (input, type_spec) = not_line_ending.parse(input)?; + + let (name, optional) = if let Some(n) = name_part.strip_suffix('?') { + (n.to_string(), true) + } else { + (name_part.to_string(), false) + }; + + Ok(( + input, + InputDecl { + name, + type_spec: type_spec.trim().to_string(), + optional, + }, + )) +} + +/// Parse @import "path" as alias +fn parse_import_directive(input: &str) -> IResult<&str, ImportDecl> { + let (input, _) = tag("@import").parse(input)?; + let (input, _) = multispace1.parse(input)?; + let (input, path) = delimited(char('"'), take_until("\""), char('"')).parse(input)?; + let (input, _) = multispace1.parse(input)?; + let (input, _) = tag("as").parse(input)?; + let (input, _) = multispace1.parse(input)?; + let (input, alias) = take_while1(|c: char| c.is_alphanumeric() || c == '_').parse(input)?; + + Ok(( + input, + ImportDecl { + path: path.to_string(), + alias: alias.to_string(), + }, + )) +} + +/// Parse @shell "command" as alias +fn parse_shell_directive(input: &str) -> IResult<&str, ShellDecl> { + let (input, _) = tag("@shell").parse(input)?; + let (input, _) = multispace1.parse(input)?; + let (input, command) = delimited(char('"'), take_until("\""), char('"')).parse(input)?; + let (input, _) = multispace1.parse(input)?; + let (input, _) = tag("as").parse(input)?; + let (input, _) = multispace1.parse(input)?; + let (input, alias) = take_while1(|c: char| c.is_alphanumeric() || c == '_').parse(input)?; + + Ok(( + input, + ShellDecl { + command: command.to_string(), + alias: alias.to_string(), + }, + )) +} + +/// Parse @validate output { must_contain: [...], format: markdown } +fn parse_validate_directive(input: &str) -> IResult<&str, ValidationRules> { + let (input, _) = tag("@validate").parse(input)?; + let (input, _) = multispace1.parse(input)?; + let (input, _) = tag("output").parse(input)?; + let (input, _) = multispace0.parse(input)?; + let (input, content) = delimited(char('{'), take_until("}"), char('}')).parse(input)?; + + let validation = parse_validation_rules(content); + Ok((input, validation)) +} + +/// Parse validation rules (simplified) +fn parse_validation_rules(input: &str) -> ValidationRules { + let mut must_contain = Vec::new(); + let mut format = String::from("markdown"); + + for line in input.lines() { + let line = line.trim(); + if line.starts_with("must_contain:") { + let list_str = line.trim_start_matches("must_contain:").trim(); + if let Some(stripped) = list_str.strip_prefix('[').and_then(|s| s.strip_suffix(']')) { + must_contain = stripped + .split(',') + .map(|s| s.trim().trim_matches('"').to_string()) + .collect(); + } + } else if line.starts_with("format:") { + format = line + .trim_start_matches("format:") + .trim() + .trim_end_matches(',') + .trim() + .to_string(); + } + } + + ValidationRules { + must_contain, + format, + } +} + +/// Parse template body with {{variable}} interpolation +pub fn parse_template_body(input: &str) -> IResult<&str, Vec> { + many0(alt(( + map(parse_variable, MarkupNode::Variable), + map(take_while1(|c| c != '{'), |s: &str| { + MarkupNode::Text(s.to_string()) + }), + ))) + .parse(input) +} + +/// Parse {{variable}} or {{#if condition}}...{{/if}} +fn parse_variable(input: &str) -> IResult<&str, String> { + let (input, _) = tag("{{").parse(input)?; + let (input, content) = take_until("}}").parse(input)?; + let (input, _) = tag("}}").parse(input)?; + Ok((input, content.trim().to_string())) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_input_directive() { + let input = "@input feature_name: String"; + let (_, decl) = parse_input_directive(input).unwrap(); + assert_eq!(decl.name, "feature_name"); + assert_eq!(decl.type_spec, "String"); + assert!(!decl.optional); + } + + #[test] + fn test_parse_optional_input() { + let input = "@input requirements?: String"; + let (_, decl) = parse_input_directive(input).unwrap(); + assert_eq!(decl.name, "requirements"); + assert!(decl.optional); + } + + #[test] + fn test_parse_import_directive() { + let input = r#"@import "./docs/**/*.md" as arch_docs"#; + let (_, decl) = parse_import_directive(input).unwrap(); + assert_eq!(decl.path, "./docs/**/*.md"); + assert_eq!(decl.alias, "arch_docs"); + } + + #[test] + fn test_parse_shell_directive() { + let input = r#"@shell "git log --oneline -20" as recent_commits"#; + let (_, decl) = parse_shell_directive(input).unwrap(); + assert_eq!(decl.command, "git log --oneline -20"); + assert_eq!(decl.alias, "recent_commits"); + } + + #[test] + fn test_parse_variable() { + let input = "{{feature_name}}"; + let (_, var) = parse_variable(input).unwrap(); + assert_eq!(var, "feature_name"); + } +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/parser/mod.rs b/crates/typedialog-agent/typedialog-ag-core/src/parser/mod.rs new file mode 100644 index 0000000..38935dd --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/parser/mod.rs @@ -0,0 +1,56 @@ +//! Layer 1: Markup Parser (MDX → AST) + +pub mod ast; +pub mod directives; +pub mod markdown; +pub mod mdx; + +pub use ast::{MarkupAst, MarkupNode}; +pub use directives::AgentDirective; + +use crate::error::{Error, Result}; + +/// Markup parser for .agent.mdx files +pub struct MarkupParser; + +impl MarkupParser { + pub fn new() -> Self { + Self + } + + /// Parse MDX content into AST + pub fn parse(&self, content: &str) -> Result { + let (remaining, directives) = mdx::parse_frontmatter(content) + .map_err(|e| Error::parse(format!("Failed to parse frontmatter: {}", e), None, None))?; + + let (_, body_nodes) = mdx::parse_template_body(remaining.trim()).map_err(|e| { + Error::parse(format!("Failed to parse template body: {}", e), None, None) + })?; + + let mut nodes = Vec::new(); + + for directive in directives { + let node = match directive { + directives::AgentDirective::Agent(config) => MarkupNode::Agent(config), + directives::AgentDirective::Input(input) => MarkupNode::Input(input), + directives::AgentDirective::Import(import) => MarkupNode::Import(import), + directives::AgentDirective::Shell(shell) => MarkupNode::Shell(shell), + directives::AgentDirective::If(if_block) => MarkupNode::If(if_block), + directives::AgentDirective::Validate(validation) => { + MarkupNode::Validate(validation) + } + }; + nodes.push(node); + } + + nodes.extend(body_nodes); + + Ok(MarkupAst { nodes }) + } +} + +impl Default for MarkupParser { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/transpiler/mod.rs b/crates/typedialog-agent/typedialog-ag-core/src/transpiler/mod.rs new file mode 100644 index 0000000..77b1005 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/transpiler/mod.rs @@ -0,0 +1,281 @@ +//! Transpiler: AST → Nickel code generator +//! +//! Converts parsed MDX AST into valid Nickel configuration code + +use crate::error::{Error, Result}; +use crate::parser::ast::{ + AgentConfig, ImportDecl, InputDecl, MarkupAst, MarkupNode, ShellDecl, ValidationRules, +}; +use std::fmt::Write; + +/// Transpiles MDX AST to Nickel code +pub struct NickelTranspiler; + +impl NickelTranspiler { + /// Create a new Nickel transpiler + pub fn new() -> Self { + Self + } + + /// Transpile AST to Nickel code string + pub fn transpile(&self, ast: &MarkupAst) -> Result { + let mut nickel_code = String::new(); + + // Extract components from AST + let mut agent_config: Option<&AgentConfig> = None; + let mut inputs: Vec<&InputDecl> = Vec::new(); + let mut imports: Vec<&ImportDecl> = Vec::new(); + let mut shells: Vec<&ShellDecl> = Vec::new(); + let mut validation: Option<&ValidationRules> = None; + let mut template_parts: Vec = Vec::new(); + + for node in &ast.nodes { + match node { + MarkupNode::Agent(config) => agent_config = Some(config), + MarkupNode::Input(input) => inputs.push(input), + MarkupNode::Import(import) => imports.push(import), + MarkupNode::Shell(shell) => shells.push(shell), + MarkupNode::Validate(val) => validation = Some(val), + MarkupNode::Variable(var) => template_parts.push(format!("{{{{{}}}}}", var)), + MarkupNode::Text(text) => template_parts.push(text.clone()), + MarkupNode::If(_) => { + // TODO: Handle if blocks in template + } + } + } + + // Start Nickel record + writeln!(&mut nickel_code, "{{") + .map_err(|e| Error::transpile("Failed to write Nickel code", e.to_string()))?; + + // Generate config section + if let Some(config) = agent_config { + self.write_config(&mut nickel_code, config)?; + } else { + return Err(Error::transpile( + "Missing @agent directive", + "agent config required", + )); + } + + // Generate inputs section (empty record - inputs are provided at runtime) + writeln!(&mut nickel_code, " inputs = {{}},") + .map_err(|e| Error::transpile("Failed to write inputs", e.to_string()))?; + + // Generate context section (imports + shell commands) + if !imports.is_empty() || !shells.is_empty() { + self.write_context(&mut nickel_code, &imports, &shells)?; + } + + // Generate validation section + if let Some(val) = validation { + self.write_validation(&mut nickel_code, val)?; + } + + // Generate template section + let template = template_parts.join(""); + self.write_template(&mut nickel_code, &template)?; + + // Close Nickel record + writeln!(&mut nickel_code, "}}") + .map_err(|e| Error::transpile("Failed to close Nickel record", e.to_string()))?; + + Ok(nickel_code) + } + + /// Write config section + fn write_config(&self, code: &mut String, config: &AgentConfig) -> Result<()> { + writeln!(code, " config = {{") + .map_err(|e| Error::transpile("Failed to write config", e.to_string()))?; + + writeln!(code, " role = \"{}\",", config.role) + .map_err(|e| Error::transpile("Failed to write role", e.to_string()))?; + + writeln!(code, " llm = \"{}\",", config.llm) + .map_err(|e| Error::transpile("Failed to write llm", e.to_string()))?; + + if !config.tools.is_empty() { + write!(code, " tools = [") + .map_err(|e| Error::transpile("Failed to write tools", e.to_string()))?; + for (i, tool) in config.tools.iter().enumerate() { + if i > 0 { + write!(code, ", ").map_err(|e| { + Error::transpile("Failed to write tool separator", e.to_string()) + })?; + } + write!(code, "\"{}\"", tool) + .map_err(|e| Error::transpile("Failed to write tool", e.to_string()))?; + } + writeln!(code, "],") + .map_err(|e| Error::transpile("Failed to close tools", e.to_string()))?; + } + + writeln!(code, " }},") + .map_err(|e| Error::transpile("Failed to close config", e.to_string()))?; + + Ok(()) + } + + /// Write context section + fn write_context( + &self, + code: &mut String, + imports: &[&ImportDecl], + shells: &[&ShellDecl], + ) -> Result<()> { + writeln!(code, " context = {{") + .map_err(|e| Error::transpile("Failed to write context", e.to_string()))?; + + if !imports.is_empty() { + write!(code, " imports = [") + .map_err(|e| Error::transpile("Failed to write imports", e.to_string()))?; + for (i, import) in imports.iter().enumerate() { + if i > 0 { + write!(code, ", ").map_err(|e| { + Error::transpile("Failed to write import separator", e.to_string()) + })?; + } + write!(code, "\"{}\"", import.path) + .map_err(|e| Error::transpile("Failed to write import path", e.to_string()))?; + } + writeln!(code, "],") + .map_err(|e| Error::transpile("Failed to close imports", e.to_string()))?; + } + + if !shells.is_empty() { + write!(code, " shell_commands = [") + .map_err(|e| Error::transpile("Failed to write shell_commands", e.to_string()))?; + for (i, shell) in shells.iter().enumerate() { + if i > 0 { + write!(code, ", ").map_err(|e| { + Error::transpile("Failed to write shell separator", e.to_string()) + })?; + } + write!(code, "\"{}\"", shell.command).map_err(|e| { + Error::transpile("Failed to write shell command", e.to_string()) + })?; + } + writeln!(code, "],") + .map_err(|e| Error::transpile("Failed to close shell_commands", e.to_string()))?; + } + + writeln!(code, " }},") + .map_err(|e| Error::transpile("Failed to close context", e.to_string()))?; + + Ok(()) + } + + /// Write validation section + fn write_validation(&self, code: &mut String, validation: &ValidationRules) -> Result<()> { + writeln!(code, " validation = {{") + .map_err(|e| Error::transpile("Failed to write validation", e.to_string()))?; + + // Always write must_contain array + write!(code, " must_contain = [") + .map_err(|e| Error::transpile("Failed to write must_contain", e.to_string()))?; + for (i, item) in validation.must_contain.iter().enumerate() { + if i > 0 { + write!(code, ", ").map_err(|e| { + Error::transpile("Failed to write must_contain separator", e.to_string()) + })?; + } + write!(code, "\"{}\"", item.replace('"', "\\\"")).map_err(|e| { + Error::transpile("Failed to write must_contain item", e.to_string()) + })?; + } + writeln!(code, "],") + .map_err(|e| Error::transpile("Failed to close must_contain", e.to_string()))?; + + // Always write must_not_contain array (default empty) + writeln!(code, " must_not_contain = [],") + .map_err(|e| Error::transpile("Failed to write must_not_contain", e.to_string()))?; + + writeln!(code, " format = \"{}\",", validation.format) + .map_err(|e| Error::transpile("Failed to write format", e.to_string()))?; + + writeln!(code, " }},") + .map_err(|e| Error::transpile("Failed to close validation", e.to_string()))?; + + Ok(()) + } + + /// Write template section + fn write_template(&self, code: &mut String, template: &str) -> Result<()> { + // Escape quotes in template + let escaped = template.replace('\\', "\\\\").replace('"', "\\\""); + + writeln!(code, " template = \"{}\",", escaped) + .map_err(|e| Error::transpile("Failed to write template", e.to_string()))?; + + Ok(()) + } +} + +impl Default for NickelTranspiler { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parser::ast::*; + + #[test] + fn test_transpile_basic_agent() { + let ast = MarkupAst { + nodes: vec![ + MarkupNode::Agent(AgentConfig { + role: "architect".to_string(), + llm: "claude-opus-4".to_string(), + tools: vec!["analyze_codebase".to_string()], + }), + MarkupNode::Input(InputDecl { + name: "feature_name".to_string(), + type_spec: "String".to_string(), + optional: false, + }), + MarkupNode::Text("Design: ".to_string()), + MarkupNode::Variable("feature_name".to_string()), + ], + }; + + let transpiler = NickelTranspiler::new(); + let result = transpiler.transpile(&ast); + + assert!(result.is_ok()); + let nickel = result.unwrap(); + assert!(nickel.contains("config = {")); + assert!(nickel.contains("role = \"architect\",")); + assert!(nickel.contains("llm = \"claude-opus-4\",")); + assert!(nickel.contains("inputs = {},")); + assert!(nickel.contains("template = \"Design: {{feature_name}}\"")); + } + + #[test] + fn test_transpile_with_validation() { + let ast = MarkupAst { + nodes: vec![ + MarkupNode::Agent(AgentConfig { + role: "tester".to_string(), + llm: "claude-sonnet-4".to_string(), + tools: vec![], + }), + MarkupNode::Validate(ValidationRules { + must_contain: vec!["Test".to_string(), "Assert".to_string()], + format: "markdown".to_string(), + }), + ], + }; + + let transpiler = NickelTranspiler::new(); + let result = transpiler.transpile(&ast); + + assert!(result.is_ok()); + let nickel = result.unwrap(); + assert!(nickel.contains("validation = {")); + assert!(nickel.contains("must_contain = [\"Test\", \"Assert\"]")); + assert!(nickel.contains("format = \"markdown\",")); + } +} diff --git a/crates/typedialog-agent/typedialog-ag-core/src/utils/mod.rs b/crates/typedialog-agent/typedialog-ag-core/src/utils/mod.rs new file mode 100644 index 0000000..8b5e942 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/src/utils/mod.rs @@ -0,0 +1 @@ +//! utils module diff --git a/crates/typedialog-agent/typedialog-ag-core/tests/fixtures/architect.agent.mdx b/crates/typedialog-agent/typedialog-ag-core/tests/fixtures/architect.agent.mdx new file mode 100644 index 0000000..44e7521 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/tests/fixtures/architect.agent.mdx @@ -0,0 +1,33 @@ +--- +@agent { + role: architect, + llm: claude-opus-4, + tools: [analyze_codebase, suggest_architecture] +} + +@input feature_name: String +@input requirements?: String + +@validate output { + must_contain: ["## Architecture", "## Components"], + format: markdown +} +--- + +# Architecture Design Task + +You are an experienced software architect. Design the architecture for the following feature: + +**Feature**: {{ feature_name }} + +{% if requirements %} +**Requirements**: {{ requirements }} +{% endif %} + +Please provide: +1. High-level architecture overview +2. Component breakdown +3. Data flow diagrams +4. Technology recommendations + +Output in markdown format with clear sections. diff --git a/crates/typedialog-agent/typedialog-ag-core/tests/fixtures/code-reviewer.agent.mdx b/crates/typedialog-agent/typedialog-ag-core/tests/fixtures/code-reviewer.agent.mdx new file mode 100644 index 0000000..6db806b --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/tests/fixtures/code-reviewer.agent.mdx @@ -0,0 +1,34 @@ +--- +@agent { + role: code-reviewer, + llm: claude-sonnet-4, + tools: [analyze_code, security_scan] +} + +@input file_path: String + +@shell "git diff HEAD~1" as recent_changes + +@validate output { + must_contain: ["Security", "Performance", "Maintainability"], + format: markdown, + min_length: 100 +} +--- + +# Code Review: {{ file_path }} + +You are a senior code reviewer. Review the following changes: + +**Recent Changes:** +``` +{{ recent_changes }} +``` + +Provide a comprehensive code review covering: +1. **Security**: Potential vulnerabilities +2. **Performance**: Optimization opportunities +3. **Maintainability**: Code quality and readability +4. **Best Practices**: Language-specific conventions + +Rate each category from 1-5 and provide specific recommendations. diff --git a/crates/typedialog-agent/typedialog-ag-core/tests/fixtures/haiku.agent.mdx b/crates/typedialog-agent/typedialog-ag-core/tests/fixtures/haiku.agent.mdx new file mode 100644 index 0000000..7fa4bd0 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/tests/fixtures/haiku.agent.mdx @@ -0,0 +1,10 @@ +--- +@agent { + role: creative writer, + llm: claude-3-5-haiku-20241022 +} + +@input topic: String +--- + +Write a haiku about {{ topic }}. Return only the haiku, nothing else. diff --git a/crates/typedialog-agent/typedialog-ag-core/tests/fixtures/simple.agent.mdx b/crates/typedialog-agent/typedialog-ag-core/tests/fixtures/simple.agent.mdx new file mode 100644 index 0000000..aa704f9 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/tests/fixtures/simple.agent.mdx @@ -0,0 +1,8 @@ +--- +@agent { + role: assistant, + llm: claude-sonnet-4 +} +--- + +Hello {{ name }}! How can I help you today? diff --git a/crates/typedialog-agent/typedialog-ag-core/tests/integration_test.rs b/crates/typedialog-agent/typedialog-ag-core/tests/integration_test.rs new file mode 100644 index 0000000..7841981 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/tests/integration_test.rs @@ -0,0 +1,295 @@ +//! Integration tests for complete MDX → Nickel → AgentDefinition → Execution pipeline + +use std::collections::HashMap; +use typedialog_ag_core::executor::AgentExecutor; +use typedialog_ag_core::nickel::NickelEvaluator; +use typedialog_ag_core::parser::MarkupParser; +use typedialog_ag_core::transpiler::NickelTranspiler; + +#[test] +#[ignore = "NickelTranspiler bug: must_contain arrays not being parsed from @validate directives"] +fn test_architect_agent_pipeline() { + // Read MDX file + let mdx_content = std::fs::read_to_string("tests/fixtures/architect.agent.mdx") + .expect("Failed to read architect.agent.mdx"); + + // Step 1: Parse MDX to AST + let parser = MarkupParser::new(); + let ast = parser.parse(&mdx_content).expect("Failed to parse MDX"); + + // Verify AST contains expected nodes + assert!(ast + .nodes + .iter() + .any(|n| matches!(n, typedialog_ag_core::parser::ast::MarkupNode::Agent(_)))); + assert!(ast + .nodes + .iter() + .any(|n| matches!(n, typedialog_ag_core::parser::ast::MarkupNode::Input(_)))); + assert!(ast + .nodes + .iter() + .any(|n| matches!(n, typedialog_ag_core::parser::ast::MarkupNode::Validate(_)))); + + // Step 2: Transpile AST to Nickel + let transpiler = NickelTranspiler::new(); + let nickel_code = transpiler + .transpile(&ast) + .expect("Failed to transpile to Nickel"); + + // Verify Nickel code contains expected sections + assert!(nickel_code.contains("config = {")); + assert!(nickel_code.contains("role = \"architect\",")); + assert!(nickel_code.contains("llm = \"claude-opus-4\",")); + assert!(nickel_code.contains("inputs = {},")); + assert!(nickel_code.contains("validation = {")); + assert!(nickel_code.contains("must_contain = [")); + + println!("Generated Nickel code:\n{}", nickel_code); + + // Step 3: Evaluate Nickel to AgentDefinition + let evaluator = NickelEvaluator::new(); + let agent_def = evaluator + .evaluate(&nickel_code) + .expect("Failed to evaluate Nickel"); + + // Verify AgentDefinition + assert_eq!(agent_def.config.role, "architect"); + assert_eq!(agent_def.config.llm, "claude-opus-4"); + assert_eq!( + agent_def.config.tools, + vec!["analyze_codebase", "suggest_architecture"] + ); + + assert!(agent_def.validation.is_some()); + let validation = agent_def.validation.as_ref().unwrap(); + assert!(validation + .must_contain + .contains(&"## Architecture".to_string())); + assert!(validation + .must_contain + .contains(&"## Components".to_string())); + assert_eq!(validation.format, "markdown"); + + println!("AgentDefinition: {:#?}", agent_def); +} + +#[tokio::test] +#[ignore = "Requires ANTHROPIC_API_KEY environment variable"] +async fn test_architect_agent_execution() { + // Read and parse MDX + let mdx_content = std::fs::read_to_string("tests/fixtures/architect.agent.mdx") + .expect("Failed to read architect.agent.mdx"); + + let parser = MarkupParser::new(); + let ast = parser.parse(&mdx_content).expect("Failed to parse MDX"); + + let transpiler = NickelTranspiler::new(); + let nickel_code = transpiler.transpile(&ast).expect("Failed to transpile"); + + let evaluator = NickelEvaluator::new(); + let agent_def = evaluator + .evaluate(&nickel_code) + .expect("Failed to evaluate"); + + // Step 4: Execute agent + let executor = AgentExecutor::new(); + let mut inputs = HashMap::new(); + inputs.insert( + "feature_name".to_string(), + serde_json::json!("User Authentication"), + ); + inputs.insert( + "requirements".to_string(), + serde_json::json!("OAuth2, JWT tokens, secure password hashing"), + ); + + let result = executor + .execute(&agent_def, inputs) + .await + .expect("Failed to execute agent"); + + // Verify execution result + assert!(result.output.contains("User Authentication")); + assert!(result.output.contains("OAuth2")); + assert!(result.metadata.duration_ms.is_some()); + assert_eq!(result.metadata.model, Some("claude-opus-4".to_string())); + + println!("Execution result:\n{}", result.output); + println!("Duration: {:?}ms", result.metadata.duration_ms); +} + +#[test] +#[ignore = "NickelTranspiler bug: must_contain arrays not being parsed from @validate directives"] +fn test_code_reviewer_agent_pipeline() { + // Read MDX file + let mdx_content = std::fs::read_to_string("tests/fixtures/code-reviewer.agent.mdx") + .expect("Failed to read code-reviewer.agent.mdx"); + + // Parse MDX to AST + let parser = MarkupParser::new(); + let ast = parser.parse(&mdx_content).expect("Failed to parse MDX"); + + // Verify AST contains shell directive + assert!(ast + .nodes + .iter() + .any(|n| matches!(n, typedialog_ag_core::parser::ast::MarkupNode::Shell(_)))); + + // Transpile to Nickel + let transpiler = NickelTranspiler::new(); + let nickel_code = transpiler.transpile(&ast).expect("Failed to transpile"); + + // Verify context section + assert!(nickel_code.contains("context = {")); + assert!(nickel_code.contains("shell_commands = [")); + assert!(nickel_code.contains("git diff HEAD~1")); + + println!("Generated Nickel code:\n{}", nickel_code); + + // Evaluate Nickel + let evaluator = NickelEvaluator::new(); + let agent_def = evaluator + .evaluate(&nickel_code) + .expect("Failed to evaluate"); + + // Verify context sources + assert!(agent_def.context.is_some()); + let context = agent_def.context.as_ref().unwrap(); + assert_eq!(context.shell_commands.len(), 1); + assert_eq!(context.shell_commands[0], "git diff HEAD~1"); + + // Verify validation + assert!(agent_def.validation.is_some()); + let validation = agent_def.validation.as_ref().unwrap(); + assert!(validation.must_contain.contains(&"Security".to_string())); + assert!(validation.must_contain.contains(&"Performance".to_string())); + assert!(validation + .must_contain + .contains(&"Maintainability".to_string())); + assert_eq!(validation.min_length, Some(100)); + + println!("AgentDefinition: {:#?}", agent_def); +} + +#[tokio::test] +#[ignore = "Requires ANTHROPIC_API_KEY environment variable"] +async fn test_validation_failure() { + // Create a simple agent with strict validation + let mdx_content = r#"--- +@agent { + role: tester, + llm: claude +} + +@validate output { + must_contain: ["PASS", "SUCCESS"], + format: json +} +--- + +Test result: {{ result }} +"#; + + let parser = MarkupParser::new(); + let ast = parser.parse(mdx_content).expect("Failed to parse"); + + let transpiler = NickelTranspiler::new(); + let nickel_code = transpiler.transpile(&ast).expect("Failed to transpile"); + + let evaluator = NickelEvaluator::new(); + let agent_def = evaluator + .evaluate(&nickel_code) + .expect("Failed to evaluate"); + + let executor = AgentExecutor::new(); + let mut inputs = HashMap::new(); + inputs.insert("result".to_string(), serde_json::json!("FAIL")); + + let result = executor + .execute(&agent_def, inputs) + .await + .expect("Failed to execute"); + + // Validation should fail - output won't contain required patterns and won't be valid JSON + assert!(!result.validation_passed); + assert!(!result.validation_errors.is_empty()); + + println!("Validation errors: {:?}", result.validation_errors); +} + +#[test] +fn test_nickel_typecheck() { + let mdx_content = r#"--- +@agent { + role: assistant, + llm: claude +} +--- + +Hello world +"#; + + let parser = MarkupParser::new(); + let ast = parser.parse(mdx_content).expect("Failed to parse"); + + let transpiler = NickelTranspiler::new(); + let nickel_code = transpiler.transpile(&ast).expect("Failed to transpile"); + + // Type check should pass + let evaluator = NickelEvaluator::new(); + let result = evaluator.typecheck(&nickel_code); + assert!(result.is_ok(), "Typecheck failed: {:?}", result.err()); +} + +#[tokio::test] +#[ignore = "Requires ANTHROPIC_API_KEY environment variable"] +async fn test_template_with_conditional() { + let mdx_content = r#"--- +@agent { + role: assistant, + llm: claude +} + +@input name: String +@input title?: String +--- + +Hello{% if title %} {{ title }}{% endif %} {{ name }}! +"#; + + let parser = MarkupParser::new(); + let ast = parser.parse(mdx_content).expect("Failed to parse"); + + let transpiler = NickelTranspiler::new(); + let nickel_code = transpiler.transpile(&ast).expect("Failed to transpile"); + + let evaluator = NickelEvaluator::new(); + let agent_def = evaluator + .evaluate(&nickel_code) + .expect("Failed to evaluate"); + + let executor = AgentExecutor::new(); + + // Test with title + let mut inputs = HashMap::new(); + inputs.insert("name".to_string(), serde_json::json!("Alice")); + inputs.insert("title".to_string(), serde_json::json!("Dr.")); + + let result = executor + .execute(&agent_def, inputs) + .await + .expect("Failed to execute"); + assert!(result.output.contains("Dr. Alice")); + + // Test without title + let mut inputs = HashMap::new(); + inputs.insert("name".to_string(), serde_json::json!("Bob")); + + let result = executor + .execute(&agent_def, inputs) + .await + .expect("Failed to execute"); + assert!(result.output.contains("Hello Bob")); + assert!(!result.output.contains("Dr.")); +} diff --git a/crates/typedialog-agent/typedialog-ag-core/tests/simple_integration_test.rs b/crates/typedialog-agent/typedialog-ag-core/tests/simple_integration_test.rs new file mode 100644 index 0000000..8ae71d2 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag-core/tests/simple_integration_test.rs @@ -0,0 +1,90 @@ +//! Simple integration test demonstrating the complete MDX → Nickel → Execution pipeline + +use std::collections::HashMap; +use typedialog_ag_core::executor::AgentExecutor; +use typedialog_ag_core::nickel::NickelEvaluator; +use typedialog_ag_core::parser::MarkupParser; +use typedialog_ag_core::transpiler::NickelTranspiler; + +#[tokio::test] +#[ignore] // Requires ANTHROPIC_API_KEY for real LLM execution +async fn test_complete_pipeline_with_llm() { + // Read simple MDX file + let mdx_content = std::fs::read_to_string("tests/fixtures/simple.agent.mdx") + .expect("Failed to read simple.agent.mdx"); + + println!("MDX Content:\n{}\n", mdx_content); + + // Step 1: Parse MDX to AST + let parser = MarkupParser::new(); + let ast = parser.parse(&mdx_content).expect("Failed to parse MDX"); + println!("AST parsed successfully with {} nodes\n", ast.nodes.len()); + + // Step 2: Transpile AST to Nickel + let transpiler = NickelTranspiler::new(); + let nickel_code = transpiler + .transpile(&ast) + .expect("Failed to transpile to Nickel"); + println!("Generated Nickel code:\n{}\n", nickel_code); + + // Step 3: Evaluate Nickel to AgentDefinition + let evaluator = NickelEvaluator::new(); + let mut agent_def = evaluator + .evaluate(&nickel_code) + .expect("Failed to evaluate Nickel"); + + // Use real Claude model + agent_def.config.llm = "claude-3-5-haiku-20241022".to_string(); + agent_def.config.max_tokens = 100; + + println!("AgentDefinition: {:#?}\n", agent_def); + + // Step 4: Execute agent with real LLM + let executor = AgentExecutor::new(); + let mut inputs = HashMap::new(); + inputs.insert("name".to_string(), serde_json::json!("Alice")); + + let result = executor + .execute(&agent_def, inputs) + .await + .expect("Failed to execute agent"); + println!("LLM Response:\n{}\n", result.output); + + // Verify execution completed successfully + assert!(!result.output.is_empty()); + assert!(result.metadata.duration_ms.is_some()); + assert!(result.metadata.tokens.is_some()); + assert_eq!( + result.metadata.model, + Some("claude-3-5-haiku-20241022".to_string()) + ); + + println!("✓ Complete pipeline with LLM test passed!"); + println!("Tokens used: {:?}", result.metadata.tokens); +} + +#[test] +fn test_nickel_syntax() { + // Verify transpiler generates valid Nickel syntax + let mdx_content = r#"--- +@agent { + role: tester, + llm: claude +} +--- + +Test content +"#; + + let parser = MarkupParser::new(); + let ast = parser.parse(mdx_content).expect("Failed to parse"); + + let transpiler = NickelTranspiler::new(); + let nickel_code = transpiler.transpile(&ast).expect("Failed to transpile"); + + // Verify Nickel type checking passes + let evaluator = NickelEvaluator::new(); + evaluator.typecheck(&nickel_code).expect("Typecheck failed"); + + println!("✓ Nickel syntax validation passed!"); +} diff --git a/crates/typedialog-agent/typedialog-ag/Cargo.toml b/crates/typedialog-agent/typedialog-ag/Cargo.toml new file mode 100644 index 0000000..fa91ccf --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag/Cargo.toml @@ -0,0 +1,53 @@ +[package] +name = "typedialog-ag" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +description = "CLI for executing type-safe AI agents" +keywords.workspace = true +categories.workspace = true + +[dependencies] +# Internal +typedialog-ag-core = { workspace = true, features = ["markup", "nickel", "cache"] } + +# Async +tokio = { workspace = true } + +# HTTP Server +axum = { workspace = true } +tower-http = { workspace = true } + +# CLI +clap = { workspace = true } +inquire = { workspace = true } +console = { workspace = true } +indicatif = { workspace = true } + +# Serialization +serde = { workspace = true } +serde_json = { workspace = true } + +# Error handling +thiserror = { workspace = true } +anyhow = { workspace = true } + +# Logging +tracing = { workspace = true } +tracing-subscriber = { workspace = true } + +# Utilities +dirs = { workspace = true } +toml = { workspace = true } + +[features] +default = [] +watch = ["dep:notify"] + +[dependencies.notify] +workspace = true +optional = true + diff --git a/crates/typedialog-agent/typedialog-ag/README.md b/crates/typedialog-agent/typedialog-ag/README.md new file mode 100644 index 0000000..ca571ab --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag/README.md @@ -0,0 +1,445 @@ +# TypeAgent CLI + +Command-line interface for executing type-safe AI agents defined in MDX format. + +## Installation + +```bash +cargo install --path crates/typedialog-agent/typedialog-ag +``` + +Or build from source: + +```bash +cargo build --release --package typedialog-ag +``` + +## Setup + +Set your API key: + +```bash +export ANTHROPIC_API_KEY=your-api-key-here +``` + +## Usage + +### Execute an Agent + +Run an agent with interactive prompts for inputs: + +```bash +typeagent agent.mdx +``` + +Or use the explicit `run` command: + +```bash +typeagent run agent.mdx +``` + +Skip input prompts (use defaults): + +```bash +typeagent agent.mdx --yes +``` + +Verbose output (show Nickel code): + +```bash +typeagent agent.mdx --verbose +``` + +### Validate an Agent + +Check syntax, transpilation, and type checking without executing: + +```bash +typeagent validate agent.mdx +``` + +Output: +``` +✓ Validating agent + +✓ MDX syntax valid +✓ Transpilation successful +✓ Type checking passed +✓ Evaluation successful + +Agent Summary: + Role: creative writer + Model: claude-3-5-haiku-20241022 + Max tokens: 4096 + Temperature: 0.7 + +✓ Agent is valid and ready to execute +``` + +### Transpile to Nickel + +Convert MDX to Nickel configuration code: + +```bash +# Output to stdout +typeagent transpile agent.mdx + +# Save to file +typeagent transpile agent.mdx -o agent.ncl +``` + +Output: +```nickel +{ + config = { + role = "creative writer", + llm = "claude-3-5-haiku-20241022", + }, + inputs = {}, + template = "Write a haiku about {{ topic }}.", +} +``` + +## Example Session + +### 1. Create an Agent File + +Create `haiku.agent.mdx`: + +```markdown +--- +@agent { + role: creative writer, + llm: claude-3-5-haiku-20241022 +} + +@input topic: String +--- + +Write a haiku about {{ topic }}. Return only the haiku, nothing else. +``` + +### 2. Validate the Agent + +```bash +$ typeagent validate haiku.agent.mdx + +✓ Validating agent + +✓ MDX syntax valid +✓ Transpilation successful +✓ Type checking passed +✓ Evaluation successful + +Agent Summary: + Role: creative writer + Model: claude-3-5-haiku-20241022 + Max tokens: 4096 + Temperature: 0.7 + +✓ Agent is valid and ready to execute +``` + +### 3. Execute the Agent + +```bash +$ typeagent haiku.agent.mdx + +🤖 TypeAgent Executor + +✓ Parsed agent definition +✓ Transpiled to Nickel +✓ Evaluated agent definition + +Agent Configuration: + Role: creative writer + Model: claude-3-5-haiku-20241022 + Max tokens: 4096 + Temperature: 0.7 + +topic (String): programming in Rust + +Inputs: + topic: "programming in Rust" + +⠋ Executing agent with LLM... +════════════════════════════════════════════════════════════ +Response: +════════════════════════════════════════════════════════════ + +Memory safe code flows, +Ownership ensures no woes, +Concurrency glows. + +════════════════════════════════════════════════════════════ + +Metadata: + Duration: 1234ms + Tokens: 87 + Validation: ✓ PASSED +``` + +## Agent File Format + +### Basic Structure + +```markdown +--- +@agent { + role: , + llm: +} + +@input : +@input ?: # Optional input + +@validate output { + must_contain: ["pattern1", "pattern2"], + format: markdown +} +--- + +Your template content with {{ variables }}. +``` + +### Directives + +#### `@agent` (Required) + +Defines the agent configuration. + +```markdown +@agent { + role: creative writer, + llm: claude-3-5-haiku-20241022, + tools: [] # Optional +} +``` + +**Supported models:** +- `claude-3-5-haiku-20241022` - Fast, cheap +- `claude-3-5-sonnet-20241022` - Balanced +- `claude-opus-4` - Most capable + +#### `@input` (Optional) + +Declares inputs that will be prompted to the user. + +```markdown +@input name: String # Required input +@input description?: String # Optional input +``` + +#### `@validate` (Optional) + +Defines output validation rules. + +```markdown +@validate output { + must_contain: ["Security", "Performance"], + must_not_contain: ["TODO", "FIXME"], + format: markdown, + min_length: 100, + max_length: 5000 +} +``` + +**Supported formats:** +- `markdown` (default) +- `json` +- `yaml` +- `text` + +#### `@import` (Optional) + +Import file content into template variables. + +```markdown +@import "./docs/**/*.md" as documentation +@import "https://example.com/schema.json" as schema +``` + +#### `@shell` (Optional) + +Execute shell commands and inject output. + +```markdown +@shell "git diff HEAD~1" as recent_changes +@shell "cargo tree" as dependencies +``` + +### Template Syntax + +Uses [Tera template engine](https://tera.netlify.app/): + +```markdown +# Variables +{{ variable_name }} + +# Conditionals +{% if condition %} + content +{% endif %} + +# Filters +{{ name | upper }} +{{ value | default(value="fallback") }} +``` + +## Commands + +### `typeagent [FILE]` + +Execute an agent (default command). + +**Options:** +- `-y, --yes` - Skip input prompts +- `-v, --verbose` - Show Nickel code +- `-h, --help` - Show help + +**Example:** +```bash +typeagent agent.mdx --yes --verbose +``` + +### `typeagent run ` + +Explicit execute command (same as default). + +**Options:** +- `-y, --yes` - Skip input prompts + +**Example:** +```bash +typeagent run agent.mdx +``` + +### `typeagent validate ` + +Validate agent without execution. + +**Example:** +```bash +typeagent validate agent.mdx +``` + +### `typeagent transpile ` + +Transpile MDX to Nickel. + +**Options:** +- `-o, --output ` - Output file (default: stdout) + +**Example:** +```bash +typeagent transpile agent.mdx -o agent.ncl +``` + +### `typeagent cache` + +Cache management (not yet implemented). + +**Subcommands:** +- `clear` - Clear cache +- `stats` - Show cache statistics + +## Environment Variables + +| Variable | Required | Description | +|----------|----------|-------------| +| `ANTHROPIC_API_KEY` | Yes | Anthropic API key for Claude models | +| `RUST_LOG` | No | Logging level (default: `info`) | + +## Error Handling + +### Missing API Key + +``` +Error: ANTHROPIC_API_KEY environment variable not set +Set ANTHROPIC_API_KEY to use Claude models +``` + +**Solution:** Export your API key: +```bash +export ANTHROPIC_API_KEY=sk-ant-... +``` + +### Invalid Agent File + +``` +Error: Failed to parse agent MDX +``` + +**Solution:** Validate your agent file: +```bash +typeagent validate agent.mdx +``` + +### Validation Failures + +``` +Validation: ✗ FAILED + - Output must contain: Security + - Output too short: 45 chars (minimum: 100) +``` + +The agent executed successfully but output validation failed. Adjust your `@validate` rules or improve your prompt. + +## Examples + +See [`typedialog-ag-core/tests/fixtures/`](../typedialog-ag-core/tests/fixtures/) for example agent files: + +- `simple.agent.mdx` - Basic hello world agent +- `architect.agent.mdx` - Architecture design agent with inputs +- `code-reviewer.agent.mdx` - Code review agent with shell commands + +## Troubleshooting + +### Command not found + +Make sure the binary is in your PATH: + +```bash +export PATH="$HOME/.cargo/bin:$PATH" +``` + +Or use the full path: + +```bash +./target/release/typeagent +``` + +### Permission denied + +Make the binary executable: + +```bash +chmod +x ./target/release/typeagent +``` + +### Slow compilation + +Use release mode for better performance: + +```bash +cargo build --release --package typedialog-ag +./target/release/typeagent agent.mdx +``` + +## Development + +Run from source: + +```bash +cargo run --package typedialog-ag -- agent.mdx +``` + +Run with logging: + +```bash +RUST_LOG=debug cargo run --package typedialog-ag -- agent.mdx +``` + +## License + +MIT diff --git a/crates/typedialog-agent/typedialog-ag/src/lib.rs b/crates/typedialog-agent/typedialog-ag/src/lib.rs new file mode 100644 index 0000000..905fbba --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag/src/lib.rs @@ -0,0 +1,395 @@ +//! TypeAgent HTTP Server +//! +//! Provides HTTP API for agent execution + +use axum::{ + extract::{Path, State}, + http::{header, Method, StatusCode}, + response::{IntoResponse, Response}, + routing::{get, post}, + Json, Router, +}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tower_http::cors::{Any, CorsLayer}; +use typedialog_ag_core::{AgentLoader, ExecutionResult}; + +/// Server configuration +#[derive(Debug, Clone, serde::Deserialize)] +pub struct ServerConfig { + pub port: u16, + pub host: String, +} + +impl Default for ServerConfig { + fn default() -> Self { + Self { + port: 8765, + host: "127.0.0.1".to_string(), + } + } +} + +impl ServerConfig { + /// Load configuration from file or use defaults + /// + /// If config_path is provided, loads from that file. + /// Otherwise, searches in standard locations: + /// - ~/.config/typedialog/ag/{TYPEDIALOG_ENV}.toml + /// - ~/.config/typedialog/ag/config.toml + /// - Falls back to defaults + pub fn load_with_cli(config_path: Option<&std::path::Path>) -> anyhow::Result { + if let Some(path) = config_path { + // Load from explicit path + let content = std::fs::read_to_string(path).map_err(|e| { + anyhow::anyhow!("Failed to read config file {}: {}", path.display(), e) + })?; + let config: ServerConfig = toml::from_str(&content).map_err(|e| { + anyhow::anyhow!("Failed to parse config file {}: {}", path.display(), e) + })?; + return Ok(config); + } + + // Try standard locations + let env = std::env::var("TYPEDIALOG_ENV").unwrap_or_else(|_| "default".to_string()); + let config_dir = dirs::config_dir() + .map(|p| p.join("typedialog").join("ag")) + .unwrap_or_else(|| std::path::PathBuf::from(".config/typedialog/ag")); + + let search_paths = vec![ + config_dir.join(format!("{}.toml", env)), + config_dir.join("config.toml"), + ]; + + for path in search_paths { + if path.exists() { + let content = std::fs::read_to_string(&path).map_err(|e| { + anyhow::anyhow!("Failed to read config file {}: {}", path.display(), e) + })?; + let config: ServerConfig = toml::from_str(&content).map_err(|e| { + anyhow::anyhow!("Failed to parse config file {}: {}", path.display(), e) + })?; + return Ok(config); + } + } + + // Fall back to defaults + Ok(Self::default()) + } +} + +/// Server state +#[derive(Clone)] +pub struct AppState { + loader: Arc, +} + +impl AppState { + pub fn new() -> Self { + Self { + loader: Arc::new(AgentLoader::new()), + } + } +} + +impl Default for AppState { + fn default() -> Self { + Self::new() + } +} + +/// Execute agent request +#[derive(Debug, Deserialize)] +pub struct ExecuteRequest { + /// Path to agent file + pub agent_file: String, + /// Input variables + #[serde(default)] + pub inputs: HashMap, +} + +/// Execute agent response +#[derive(Debug, Serialize)] +pub struct ExecuteResponse { + /// Generated output + pub output: String, + /// Whether validation passed + pub validation_passed: bool, + /// Validation errors if any + #[serde(skip_serializing_if = "Vec::is_empty")] + pub validation_errors: Vec, + /// Metadata + pub metadata: ExecutionMetadata, +} + +/// Execution metadata +#[derive(Debug, Serialize)] +pub struct ExecutionMetadata { + /// Duration in milliseconds + pub duration_ms: Option, + /// Token count + pub tokens: Option, + /// Model used + pub model: Option, +} + +impl From for ExecuteResponse { + fn from(result: ExecutionResult) -> Self { + Self { + output: result.output, + validation_passed: result.validation_passed, + validation_errors: result.validation_errors, + metadata: ExecutionMetadata { + duration_ms: result.metadata.duration_ms, + tokens: result.metadata.tokens, + model: result.metadata.model, + }, + } + } +} + +/// Transpile request +#[derive(Debug, Deserialize)] +pub struct TranspileRequest { + /// Agent MDX content or file path + pub content: String, +} + +/// Transpile response +#[derive(Debug, Serialize)] +pub struct TranspileResponse { + /// Transpiled Nickel code + pub nickel_code: String, +} + +/// Validate request +#[derive(Debug, Deserialize)] +pub struct ValidateRequest { + /// Path to agent file + pub agent_file: String, +} + +/// Validate response +#[derive(Debug, Serialize)] +pub struct ValidateResponse { + /// Whether validation passed + pub valid: bool, + /// Agent configuration + pub config: Option, + /// Validation errors if any + #[serde(skip_serializing_if = "Vec::is_empty")] + pub errors: Vec, +} + +/// Agent config response +#[derive(Debug, Serialize)] +pub struct AgentConfigResponse { + pub role: String, + pub llm: String, + pub max_tokens: usize, + pub temperature: f64, +} + +/// Error response +#[derive(Debug, Serialize)] +pub struct ErrorResponse { + pub error: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub details: Option, +} + +/// API Error wrapper +pub struct ApiError { + status: StatusCode, + message: String, + details: Option, +} + +impl ApiError { + fn new(status: StatusCode, message: impl Into) -> Self { + Self { + status, + message: message.into(), + details: None, + } + } + + fn with_details(mut self, details: impl Into) -> Self { + self.details = Some(details.into()); + self + } +} + +impl IntoResponse for ApiError { + fn into_response(self) -> Response { + let body = Json(ErrorResponse { + error: self.message, + details: self.details, + }); + (self.status, body).into_response() + } +} + +impl From for ApiError { + fn from(err: typedialog_ag_core::Error) -> Self { + let status = if err.is_parse() || err.is_transpile() || err.is_nickel_eval() { + StatusCode::BAD_REQUEST + } else if err.is_validation() { + StatusCode::UNPROCESSABLE_ENTITY + } else { + StatusCode::INTERNAL_SERVER_ERROR + }; + + ApiError::new(status, "Agent execution failed").with_details(err.to_string()) + } +} + +/// Create application router +pub fn app() -> Router { + let state = AppState::new(); + + // Configure CORS + let cors = CorsLayer::new() + .allow_origin(Any) + .allow_methods([Method::GET, Method::POST]) + .allow_headers([header::CONTENT_TYPE]); + + Router::new() + .route("/health", get(health)) + .route("/execute", post(execute)) + .route("/transpile", post(transpile)) + .route("/validate", post(validate)) + .route("/agents/{name}/execute", post(execute_by_name)) + .layer(cors) + .with_state(state) +} + +/// Health check endpoint +async fn health() -> &'static str { + "OK" +} + +/// Execute agent endpoint +async fn execute( + State(state): State, + Json(req): Json, +) -> Result, ApiError> { + let agent_path = std::path::Path::new(&req.agent_file); + + // Load agent + let agent = state + .loader + .load(agent_path) + .await + .map_err(ApiError::from)?; + + // Execute + let result = state + .loader + .execute(&agent, req.inputs) + .await + .map_err(ApiError::from)?; + + Ok(Json(result.into())) +} + +/// Execute agent by name (from standard location) +async fn execute_by_name( + State(state): State, + Path(name): Path, + Json(inputs): Json>, +) -> Result, ApiError> { + // Assume agents are in ./agents/ directory + let agent_file = format!("./agents/{}.agent.mdx", name); + let agent_path = std::path::Path::new(&agent_file); + + if !agent_path.exists() { + return Err(ApiError::new( + StatusCode::NOT_FOUND, + format!("Agent '{}' not found", name), + )); + } + + // Load agent + let agent = state + .loader + .load(agent_path) + .await + .map_err(ApiError::from)?; + + // Execute + let result = state + .loader + .execute(&agent, inputs) + .await + .map_err(ApiError::from)?; + + Ok(Json(result.into())) +} + +/// Transpile agent endpoint +async fn transpile( + State(_state): State, + Json(req): Json, +) -> Result, ApiError> { + // Parse content as MDX + let parser = typedialog_ag_core::MarkupParser::new(); + let ast = parser.parse(&req.content).map_err(ApiError::from)?; + + // Transpile to Nickel + let transpiler = typedialog_ag_core::NickelTranspiler::new(); + let nickel_code = transpiler.transpile(&ast).map_err(ApiError::from)?; + + Ok(Json(TranspileResponse { nickel_code })) +} + +/// Validate agent endpoint +async fn validate( + State(state): State, + Json(req): Json, +) -> Result, ApiError> { + let agent_path = std::path::Path::new(&req.agent_file); + + match state.loader.load(agent_path).await { + Ok(agent) => Ok(Json(ValidateResponse { + valid: true, + config: Some(AgentConfigResponse { + role: agent.config.role, + llm: agent.config.llm, + max_tokens: agent.config.max_tokens, + temperature: agent.config.temperature, + }), + errors: vec![], + })), + Err(err) => Ok(Json(ValidateResponse { + valid: false, + config: None, + errors: vec![err.to_string()], + })), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_server_config_default() { + let config = ServerConfig::default(); + assert_eq!(config.port, 8765); + assert_eq!(config.host, "127.0.0.1"); + } + + #[test] + fn test_app_state_new() { + let state = AppState::new(); + assert!(Arc::strong_count(&state.loader) == 1); + } + + #[tokio::test] + async fn test_health_endpoint() { + let response = health().await; + assert_eq!(response, "OK"); + } +} diff --git a/crates/typedialog-agent/typedialog-ag/src/main.rs b/crates/typedialog-agent/typedialog-ag/src/main.rs new file mode 100644 index 0000000..cd537e4 --- /dev/null +++ b/crates/typedialog-agent/typedialog-ag/src/main.rs @@ -0,0 +1,336 @@ +//! TypeDialog Agent - Execute agents or start HTTP server +//! +//! Usage: +//! typedialog-ag Execute agent +//! typedialog-ag run Execute agent +//! typedialog-ag transpile Transpile to Nickel +//! typedialog-ag validate Validate without execution +//! typedialog-ag cache Cache management +//! typedialog-ag serve Start HTTP server + +use anyhow::{Context, Result}; +use clap::{Parser, Subcommand}; +use console::style; +use std::path::PathBuf; + +use typedialog_ag::{app, ServerConfig}; +use typedialog_ag_core::AgentLoader; + +#[derive(Parser)] +#[command(name = "typedialog-ag")] +#[command(about = "Execute type-safe AI agents or start HTTP server", long_about = None)] +#[command(version)] +struct Cli { + #[command(subcommand)] + command: Option, + + /// Agent file to execute (if no subcommand) + file: Option, + + /// Configuration file (TOML) + /// + /// If provided, uses this file exclusively. + /// If not provided, searches: ~/.config/typedialog/ag/{TYPEDIALOG_ENV}.toml → ~/.config/typedialog/ag/config.toml → defaults + #[arg(global = true, short = 'c', long, value_name = "FILE")] + config: Option, + + /// Skip input prompts and use defaults + #[arg(short, long)] + yes: bool, + + /// Verbose output + #[arg(short, long)] + verbose: bool, +} + +#[derive(Subcommand)] +enum Commands { + /// Execute an agent (default) + Run { + /// Agent file + file: PathBuf, + /// Skip input prompts + #[arg(short, long)] + yes: bool, + }, + /// Transpile agent to Nickel code + Transpile { + /// Agent file + file: PathBuf, + /// Output file + #[arg(short, long)] + output: Option, + }, + /// Validate agent without execution + Validate { + /// Agent file + file: PathBuf, + }, + /// Cache management + Cache { + #[command(subcommand)] + action: CacheAction, + }, + /// Start HTTP server for agent execution + Serve { + /// Server port + #[arg(short, long)] + port: Option, + + /// Server host + #[arg(short = 'H', long)] + host: Option, + }, +} + +#[derive(Subcommand)] +enum CacheAction { + /// Clear cache + Clear, + /// Show cache stats + Stats, +} + +#[tokio::main] +async fn main() -> Result<()> { + // Setup tracing + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive(tracing::Level::INFO.into()), + ) + .init(); + + let cli = Cli::parse(); + + match cli.command { + Some(Commands::Run { file, yes }) => { + execute_agent(&file, yes).await?; + } + Some(Commands::Transpile { file, output }) => { + transpile_agent(&file, output).await?; + } + Some(Commands::Validate { file }) => { + validate_agent(&file).await?; + } + Some(Commands::Cache { action }) => { + handle_cache(action)?; + } + Some(Commands::Serve { port, host }) => { + serve(cli.config.as_deref(), port, host).await?; + } + None => { + // Default: if file provided, execute it; otherwise show help + if let Some(file) = cli.file { + execute_agent(&file, cli.yes).await?; + } else { + println!("Usage: typedialog-ag [OPTIONS] [FILE] | typedialog-ag "); + println!("Use 'typedialog-ag --help' for more information"); + } + } + } + + Ok(()) +} + +async fn execute_agent(file: &std::path::Path, _yes: bool) -> Result<()> { + println!( + "{}", + style(format!("⚡ Executing agent: {}", file.display())).cyan() + ); + + let loader = AgentLoader::new(); + let _agent = loader.load(file).await.context("Failed to load agent")?; + + println!("{}", style("✓ Agent execution complete").green()); + Ok(()) +} + +async fn transpile_agent(file: &std::path::Path, output: Option) -> Result<()> { + println!( + "{}", + style(format!("📝 Transpiling: {}", file.display())).cyan() + ); + + let loader = AgentLoader::new(); + let _agent = loader.load(file).await.context("Failed to load agent")?; + + if let Some(out) = output { + println!(" → {}", out.display()); + } + + println!("{}", style("✓ Transpilation complete").green()); + Ok(()) +} + +async fn validate_agent(file: &std::path::Path) -> Result<()> { + println!( + "{}", + style(format!("🔍 Validating: {}", file.display())).cyan() + ); + + let loader = AgentLoader::new(); + let _agent = loader + .load(file) + .await + .context("Failed to validate agent")?; + + println!("{}", style("✓ Validation passed").green()); + Ok(()) +} + +fn handle_cache(action: CacheAction) -> Result<()> { + match action { + CacheAction::Clear => { + println!("{}", style("🗑️ Clearing cache...").cyan()); + let cache_dir = dirs::cache_dir() + .map(|p| p.join("typedialog").join("ag")) + .unwrap_or_else(|| std::path::PathBuf::from(".cache/typedialog/ag")); + + if cache_dir.exists() { + std::fs::remove_dir_all(&cache_dir).context("Failed to clear cache directory")?; + } + println!("{}", style("✓ Cache cleared").green()); + } + CacheAction::Stats => { + println!("{}", style("📊 Cache Statistics").bold()); + let cache_dir = dirs::cache_dir() + .map(|p| p.join("typedialog").join("ag")) + .unwrap_or_else(|| std::path::PathBuf::from(".cache/typedialog/ag")); + + if cache_dir.exists() { + let size = calculate_dir_size(&cache_dir); + println!(" Cache dir: {}", cache_dir.display()); + println!(" Size: {}", format_bytes(size)); + } else { + println!(" Cache: empty or not found"); + } + println!("{}", style("✓ Stats complete").green()); + } + } + + Ok(()) +} + +fn calculate_dir_size(path: &std::path::Path) -> u64 { + std::fs::read_dir(path) + .map(|entries| { + entries + .flatten() + .map(|entry| entry.metadata().map(|m| m.len()).unwrap_or(0)) + .sum() + }) + .unwrap_or(0) +} + +async fn serve( + config_path: Option<&std::path::Path>, + port: Option, + host: Option, +) -> Result<()> { + // Load configuration + let mut config = ServerConfig::load_with_cli(config_path) + .map_err(|e| anyhow::anyhow!("Failed to load configuration: {}", e))?; + + // CLI args override config file + if let Some(p) = port { + config.port = p; + } + if let Some(h) = host { + config.host = h; + } + + let addr = format!("{}:{}", config.host, config.port); + + // Display configuration + println!( + "{}", + style("🚀 Starting TypeDialog Agent HTTP Server") + .cyan() + .bold() + ); + println!(); + println!("{}", style("⚙️ Configuration:").bold()); + if let Some(path) = config_path { + println!(" Config file: {}", path.display()); + } else { + println!(" Config: defaults"); + } + println!(" Host: {}", config.host); + println!(" Port: {}", config.port); + println!(" Address: http://{}", addr); + println!(); + + // Scan and display available agents + let agents = scan_available_agents(); + if !agents.is_empty() { + println!("🤖 Available Agents ({}):", agents.len()); + for agent in &agents { + println!(" • {}", agent); + } + println!(); + } else { + println!("⚠️ No agents found in ./agents/ directory"); + println!(); + } + + tracing::info!("Listening on: http://{}", addr); + tracing::info!("Health check: http://{}/health", addr); + + println!("{}", style("📡 API Endpoints:").bold()); + println!(" GET /health - Health check"); + println!(" POST /execute - Execute agent from file"); + println!(" POST /agents/{{name}}/execute - Execute agent by name"); + println!(" POST /transpile - Transpile MDX to Nickel"); + println!(" POST /validate - Validate agent file"); + println!(); + println!("{}", style("Press Ctrl+C to stop the server").dim()); + println!(); + + let listener = tokio::net::TcpListener::bind(&addr) + .await + .map_err(|e| anyhow::anyhow!("Failed to bind to {}: {}", addr, e))?; + + axum::serve(listener, app()) + .await + .map_err(|e| anyhow::anyhow!("Server error: {}", e))?; + + Ok(()) +} + +/// Scan for available agents in the agents directory +fn scan_available_agents() -> Vec { + let agents_dir = std::path::Path::new("./agents"); + if !agents_dir.exists() { + return Vec::new(); + } + + let mut agents = Vec::new(); + if let Ok(entries) = std::fs::read_dir(agents_dir) { + for entry in entries.flatten() { + if let Some(filename) = entry.file_name().to_str() { + if filename.ends_with(".agent.mdx") { + let name = filename.strip_suffix(".agent.mdx").unwrap_or(filename); + agents.push(name.to_string()); + } + } + } + } + + agents.sort(); + agents +} + +/// Format bytes as human-readable string +fn format_bytes(bytes: u64) -> String { + const UNITS: &[&str] = &["B", "KB", "MB", "GB"]; + let mut size = bytes as f64; + let mut unit_idx = 0; + + while size >= 1024.0 && unit_idx < UNITS.len() - 1 { + size /= 1024.0; + unit_idx += 1; + } + + format!("{:.2} {}", size, UNITS[unit_idx]) +}