Skip to main content

Overview

This example demonstrates how to build a complete, production-ready REST API for your RAG system using FastAPI. It includes authentication, file uploads, health checks, and comprehensive error handling.

Features

RESTful API

Complete REST API with async support

File Upload

Upload and index documents via API

Authentication

JWT-based authentication

Monitoring

Health checks and metrics

Complete Implementation

import os
import secrets
import shutil
import uuid
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional, List

import jwt
from dotenv import load_dotenv
from fastapi import (
    FastAPI,
    File,
    UploadFile,
    HTTPException,
    Depends,
    status,
    Security
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from mini import (
    AgenticRAG,
    LLMConfig,
    RetrievalConfig,
    ObservabilityConfig,
    EmbeddingModel,
    VectorStore
)
from pydantic import BaseModel, Field, validator

# Load variables from a local .env file before any os.getenv() call below.
load_dotenv()

# Configuration
# NOTE: the fallback value is a placeholder. Set JWT_SECRET in the environment
# for any real deployment; otherwise tokens are signed with a well-known key.
JWT_SECRET = os.getenv("JWT_SECRET", "your-secret-key-change-in-production")
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_HOURS = 24

# Uploaded files are stored here; the directory is created at import time.
UPLOAD_DIR = Path("./uploads")
UPLOAD_DIR.mkdir(exist_ok=True)

ALLOWED_EXTENSIONS = {".pdf", ".docx", ".txt", ".md"}
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB

# FastAPI app
app = FastAPI(
    title="Mini RAG API",
    description="Production-ready RAG API with Mini RAG",
    version="1.0.0"
)

# CORS middleware
# Credentialed CORS must not be combined with wildcard origins/methods/headers.
# This API authenticates with an "Authorization: Bearer" header rather than
# cookies, so credential support is switched off while origins stay open.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately in production
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security
# Bearer-token scheme used by verify_token to read the Authorization header.
security = HTTPBearer()

# Global RAG instance
# Stays None until the startup handler has built the AgenticRAG instance.
rag: Optional[AgenticRAG] = None

# Request/Response Models
class QueryRequest(BaseModel):
    """Request model for querying.

    Fields:
        question: The question text, 1-500 characters before trimming.
        top_k: Number of chunks to retrieve (1-50, default 10).
        rerank_top_k: Number of chunks kept after reranking (1-10, default 3).
        include_sources: Whether the response should list source chunks.
    """
    question: str = Field(..., min_length=1, max_length=500)
    top_k: int = Field(default=10, ge=1, le=50)
    rerank_top_k: int = Field(default=3, ge=1, le=10)
    include_sources: bool = Field(default=True)

    @validator('question')
    def validate_question(cls, v):
        """Validate and sanitize question.

        Rejects questions containing markup/template characters or a NUL
        character, trims surrounding whitespace, and rejects questions that
        are empty once trimmed.

        Raises:
            ValueError: If a forbidden character is present or the trimmed
                question is empty.
        """
        # '\x00' is the actual NUL character; the escaped form '\\x00' would
        # only match the literal four-character text "\x00".
        dangerous_chars = ['<', '>', '{', '}', '\x00']
        for char in dangerous_chars:
            if char in v:
                raise ValueError(f"Invalid character: {char}")

        # The min_length constraint is checked before trimming, so a
        # whitespace-only question has to be rejected here.
        cleaned = v.strip()
        if not cleaned:
            raise ValueError("Question must not be empty")
        return cleaned

class QueryResponse(BaseModel):
    """Response model for queries.

    Returned by the ``POST /query`` endpoint.
    """
    # Answer text taken from the RAG response.
    answer: str
    # One dict per retrieved chunk with "text", "score" and "metadata" keys;
    # empty when the request did not ask for sources.
    sources: List[dict]
    # Metadata passed through from the RAG response.
    metadata: dict
    # UUID4 string generated for each query.
    query_id: str
    # Time at which the response object was built.
    timestamp: datetime

class DocumentUploadResponse(BaseModel):
    """Response model for document uploads.

    Returned by the ``POST /documents/upload`` endpoint.
    """
    # Name of the uploaded file as reported back to the client.
    filename: str
    # UUID4 string generated for the upload; also stored in the chunk metadata.
    document_id: str
    # Value returned by the RAG indexing call for this document.
    chunks: int
    # Set to "indexed" by the upload endpoint on success.
    status: str
    # Time at which the response object was built.
    timestamp: datetime

class HealthResponse(BaseModel):
    """Response model for health checks.

    Returned by the unauthenticated ``GET /health`` endpoint.
    """
    # "healthy" or "unhealthy", as decided by the health endpoint.
    status: str
    # Time at which the health check ran.
    timestamp: datetime
    # Whether the health endpoint reports the RAG instance as available.
    rag_initialized: bool
    # Document count from the RAG statistics; None when it was not obtained.
    total_documents: Optional[int] = None

class StatsResponse(BaseModel):
    """Response model for statistics.

    Returned by the authenticated ``GET /stats`` endpoint.
    """
    # "total_documents" entry of the RAG statistics.
    total_documents: int
    # "collection_name" entry of the RAG statistics.
    collection_name: str
    # Time at which the response object was built.
    timestamp: datetime

class LoginRequest(BaseModel):
    """Request model for login.

    Body of ``POST /auth/login``.
    """
    # Account name; becomes the "user_id" claim of the issued token.
    username: str
    # Plain-text password supplied by the client.
    password: str

class TokenResponse(BaseModel):
    """Response model for authentication.

    Returned by ``POST /auth/login`` on success.
    """
    # Encoded JWT to send as "Authorization: Bearer <token>".
    access_token: str
    # Set to "bearer" by the login endpoint.
    token_type: str
    # Token lifetime in seconds (JWT_EXPIRATION_HOURS * 3600).
    expires_in: int

# Authentication
def create_access_token(user_id: str) -> str:
    """Create JWT access token.

    Args:
        user_id: Identifier stored in the token's ``user_id`` claim.

    Returns:
        The token encoded with ``JWT_SECRET`` and ``JWT_ALGORITHM``, expiring
        ``JWT_EXPIRATION_HOURS`` hours from now.
    """
    # Timezone-aware "now": datetime.utcnow() is deprecated and returns a naive
    # value; the resulting "exp" claim is the same instant either way.
    expires_at = datetime.now(timezone.utc) + timedelta(hours=JWT_EXPIRATION_HOURS)
    payload = {
        "user_id": user_id,
        "exp": expires_at
    }
    return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)

def verify_token(
    credentials: HTTPAuthorizationCredentials = Security(security)
) -> str:
    """Validate the bearer JWT and return the user ID it carries.

    Raises:
        HTTPException: 401 when the token has expired, cannot be decoded, or
            has no ``user_id`` claim.
    """
    unauthorized = status.HTTP_401_UNAUTHORIZED

    try:
        claims = jwt.decode(
            credentials.credentials,
            JWT_SECRET,
            algorithms=[JWT_ALGORITHM],
        )
    except jwt.ExpiredSignatureError:
        # Handled before the generic clause so expiry gets its own message.
        raise HTTPException(status_code=unauthorized, detail="Token has expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=unauthorized, detail="Invalid token")

    subject = claims.get("user_id")
    if subject is None:
        # A correctly signed token without a user_id claim is still rejected.
        raise HTTPException(status_code=unauthorized, detail="Invalid token")

    return subject

# Startup/Shutdown Events
@app.on_event("startup")
async def startup_event():
    """Build the global RAG instance when the application starts.

    Creates the embedding model, the vector store and the AgenticRAG object
    from environment settings, then prints the current document count. Any
    failure is printed and re-raised.
    """
    global rag

    try:
        print("🚀 Initializing Mini RAG...")

        embedder = EmbeddingModel()

        store = VectorStore(
            uri=os.getenv("MILVUS_URI"),
            token=os.getenv("MILVUS_TOKEN"),
            collection_name=os.getenv("COLLECTION_NAME", "api_documents"),
            dimension=1536,
        )

        # Build each config object on its own line before wiring them together.
        llm_settings = LLMConfig(
            model=os.getenv("LLM_MODEL", "gpt-4o-mini"),
            temperature=0.7,
        )
        retrieval_settings = RetrievalConfig(
            top_k=10,
            rerank_top_k=3,
            use_query_rewriting=True,
            use_reranking=True,
            use_hybrid_search=True,
        )
        observability_on = os.getenv("ENABLE_OBSERVABILITY", "false").lower() == "true"
        observability_settings = ObservabilityConfig(enabled=observability_on)

        rag = AgenticRAG(
            vector_store=store,
            embedding_model=embedder,
            llm_config=llm_settings,
            retrieval_config=retrieval_settings,
            observability_config=observability_settings,
        )

        stats = rag.get_stats()
        print(f"✅ Mini RAG initialized successfully")
        print(f"📊 Total documents: {stats['total_documents']}")

    except Exception as e:
        print(f"❌ Failed to initialize Mini RAG: {e}")
        raise

@app.on_event("shutdown")
async def shutdown_event():
    """Disconnect the vector store when the application stops.

    Does nothing if the RAG instance was never created. Errors raised while
    disconnecting are printed rather than propagated.
    """
    if not rag:
        return

    try:
        rag.vector_store.disconnect()
        print("✅ Mini RAG shut down successfully")
    except Exception as e:
        print(f"⚠️  Error during shutdown: {e}")

# API Endpoints

@app.post("/auth/login", response_model=TokenResponse)
async def login(request: LoginRequest):
    """
    Authenticate and receive an access token.

    In production, validate credentials against a database.

    Returns:
        A TokenResponse holding a bearer token for the supplied username.

    Raises:
        HTTPException: 401 when the credentials do not match.
    """
    # TODO: Implement proper user authentication
    # This is a simplified example
    # Compare as bytes with secrets.compare_digest so the check takes the same
    # time wherever the first mismatch is; both comparisons always run.
    # "surrogatepass" keeps malformed text from raising during encoding.
    username_ok = secrets.compare_digest(
        request.username.encode("utf-8", "surrogatepass"), b"demo"
    )
    password_ok = secrets.compare_digest(
        request.password.encode("utf-8", "surrogatepass"), b"demo123"
    )

    if not (username_ok and password_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials"
        )

    token = create_access_token(request.username)
    return TokenResponse(
        access_token=token,
        token_type="bearer",
        expires_in=JWT_EXPIRATION_HOURS * 3600
    )

# Maximum number of characters of chunk text echoed back per source.
_SOURCE_PREVIEW_CHARS = 200


def _format_source(chunk) -> dict:
    """Convert one retrieved chunk into the dict returned to the client."""
    text = chunk.text
    # Only add an ellipsis when the text was actually shortened.
    if len(text) > _SOURCE_PREVIEW_CHARS:
        text = text[:_SOURCE_PREVIEW_CHARS] + "..."

    # Prefer the reranked score whenever one is present; an explicit None test
    # keeps a legitimate 0.0 rerank score instead of falling back.
    # NOTE(review): assumes a missing rerank score is None - confirm.
    reranked = chunk.reranked_score
    score = reranked if reranked is not None else chunk.score

    return {
        "text": text,
        "score": score,
        "metadata": chunk.metadata
    }


@app.post("/query", response_model=QueryResponse)
async def query(
    request: QueryRequest,
    user_id: str = Depends(verify_token)
):
    """
    Query the RAG system with a question.

    Requires authentication.

    Returns:
        A QueryResponse with the answer, optional source previews, the RAG
        metadata, a fresh query ID and a timestamp.

    Raises:
        HTTPException: 503 when the RAG system is not initialized, 500 when
            the query fails.
    """
    if not rag:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized"
        )

    try:
        # Query the RAG system
        # NOTE(review): this call is not awaited, so it appears to run
        # synchronously inside the async handler - confirm whether it should
        # be moved off the event loop.
        response = rag.query(
            query=request.question,
            top_k=request.top_k,
            rerank_top_k=request.rerank_top_k,
            return_sources=request.include_sources
        )

        # Format sources
        sources = []
        if request.include_sources:
            sources = [_format_source(chunk) for chunk in response.retrieved_chunks]

        return QueryResponse(
            answer=response.answer,
            sources=sources,
            metadata=response.metadata,
            query_id=str(uuid.uuid4()),
            timestamp=datetime.utcnow()
        )

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Query failed: {str(e)}"
        )

@app.post("/documents/upload", response_model=DocumentUploadResponse)
async def upload_document(
    file: UploadFile = File(...),
    user_id: str = Depends(verify_token)
):
    """
    Upload and index a document.

    Requires authentication.

    The file is stored under UPLOAD_DIR as "<document_id>_<filename>" and then
    passed to the RAG system for indexing.

    Returns:
        A DocumentUploadResponse describing the indexed document.

    Raises:
        HTTPException: 503 when the RAG system is not initialized, 400 for a
            disallowed extension or an oversized file, 500 when saving or
            indexing fails.
    """
    if not rag:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized"
        )

    # Keep only the last path component of the client-supplied name so that
    # directory parts (with either kind of slash) can never reach the
    # filesystem path. A missing filename becomes "" and is rejected below.
    safe_name = Path((file.filename or "").replace("\\", "/")).name

    # Validate file extension
    file_ext = Path(safe_name).suffix.lower()
    if file_ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"File type not allowed. Allowed: {ALLOWED_EXTENSIONS}"
        )

    # Generate unique filename
    document_id = str(uuid.uuid4())
    file_path = UPLOAD_DIR / f"{document_id}_{safe_name}"

    try:
        # Save uploaded file in chunks, enforcing the size limit while
        # copying so an oversized upload is never written to disk in full.
        file_size = 0
        with file_path.open("wb") as buffer:
            while True:
                chunk = file.file.read(1024 * 1024)
                if not chunk:
                    break
                file_size += len(chunk)
                if file_size > MAX_FILE_SIZE:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail=f"File too large. Max size: {MAX_FILE_SIZE / 1024 / 1024}MB"
                    )
                buffer.write(chunk)

        # Index the document
        num_chunks = rag.index_document(
            str(file_path),
            metadata={
                "document_id": document_id,
                "filename": safe_name,
                "uploaded_by": user_id,
                "uploaded_at": datetime.utcnow().isoformat(),
                "file_size": file_size
            }
        )

        return DocumentUploadResponse(
            filename=safe_name,
            document_id=document_id,
            chunks=num_chunks,
            status="indexed",
            timestamp=datetime.utcnow()
        )

    except HTTPException:
        # Deliberate client errors (e.g. "File too large") keep their status
        # code instead of being rewrapped as a 500.
        if file_path.exists():
            file_path.unlink()
        raise
    except Exception as e:
        # Clean up on error
        if file_path.exists():
            file_path.unlink()

        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Upload failed: {str(e)}"
        )

@app.delete("/documents/{document_id}")
async def delete_document(
    document_id: str,
    user_id: str = Depends(verify_token)
):
    """
    Delete a document.

    Requires authentication.

    Removes the document's chunks from the vector store and deletes any
    matching files from UPLOAD_DIR.

    Returns:
        A dict with the document ID, the status "deleted" and the number of
        chunks removed.

    Raises:
        HTTPException: 503 when the RAG system is not initialized, 400 when
            the ID is not a valid UUID, 404 when nothing was deleted, 500 on
            any other failure.
    """
    if not rag:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized"
        )

    # Document IDs are generated with uuid4 on upload. Parsing the ID and
    # re-serialising it guarantees that only hex digits and hyphens reach the
    # filter expression and the glob pattern below, so quotes or wildcards in
    # the path parameter cannot widen the deletion.
    try:
        canonical_id = str(uuid.UUID(document_id))
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid document ID"
        )

    try:
        # Delete from vector store
        deleted = rag.vector_store.delete(f'metadata["document_id"] == "{canonical_id}"')

        if deleted == 0:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Document not found"
            )

        # Delete file if exists
        for file_path in UPLOAD_DIR.glob(f"{canonical_id}_*"):
            file_path.unlink()

        return {
            "document_id": canonical_id,
            "status": "deleted",
            "chunks_deleted": deleted
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Deletion failed: {str(e)}"
        )

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Reports "healthy" only when the RAG instance exists and its statistics
    can be read. ``rag_initialized`` reflects whether the instance exists,
    independently of whether the statistics call succeeded.
    """
    rag_ready = rag is not None
    is_healthy = rag_ready
    total_docs = None

    if rag_ready:
        try:
            stats = rag.get_stats()
            total_docs = stats['total_documents']
        except Exception:
            # Catch Exception rather than using a bare except so task
            # cancellation and interpreter shutdown are not swallowed.
            is_healthy = False

    return HealthResponse(
        status="healthy" if is_healthy else "unhealthy",
        timestamp=datetime.utcnow(),
        rag_initialized=rag_ready,
        total_documents=total_docs
    )

@app.get("/stats", response_model=StatsResponse)
async def get_stats(user_id: str = Depends(verify_token)):
    """
    Return the document count and collection name of the RAG system.

    Requires authentication.

    Raises:
        HTTPException: 503 when the RAG system is not initialized, 500 when
            the statistics cannot be read.
    """
    if not rag:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized"
        )

    try:
        snapshot = rag.get_stats()
        document_total = snapshot['total_documents']
        collection = snapshot['collection_name']
        return StatsResponse(
            total_documents=document_total,
            collection_name=collection,
            timestamp=datetime.utcnow()
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get stats: {str(e)}"
        )

@app.get("/")
async def root():
    """Return basic service information and the location of the API docs."""
    service_info = dict(
        name="Mini RAG API",
        version="1.0.0",
        status="running",
        docs="/docs",
    )
    return service_info

# Error handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Custom HTTP exception handler.

    Converts an HTTPException into a JSON body containing the error detail,
    the status code and a timestamp, while preserving the exception's HTTP
    status code and headers.
    """
    # An exception handler must return a Response object; a plain dict is not
    # converted for handlers, and the status code would be lost.
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "status_code": exc.status_code,
            "timestamp": datetime.utcnow().isoformat()
        },
        headers=getattr(exc, "headers", None)
    )

if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    # PORT comes from the environment and falls back to 8000.
    listen_port = int(os.getenv("PORT", 8000))

    uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")

Running the API

1

Install Dependencies

# The server imports `jwt` (PyJWT) and `dotenv` (python-dotenv); python-jose
# exposes a different module (`jose`) and does not satisfy `import jwt`.
uv add mini-rag fastapi uvicorn python-multipart pyjwt python-dotenv
2

Set Environment Variables

Create a .env file:
OPENAI_API_KEY=sk-...
MILVUS_URI=https://...
MILVUS_TOKEN=...
JWT_SECRET=your-secret-key-change-in-production
COLLECTION_NAME=api_documents
LLM_MODEL=gpt-4o-mini
ENABLE_OBSERVABILITY=false
PORT=8000
3

Run the Server

python main.py
Or with uvicorn directly:
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
4

Access API Documentation

Navigate to http://localhost:8000/docs for interactive API docs

API Usage Examples

1. Authentication

# Login to get access token
curl -X POST "http://localhost:8000/auth/login" \
  -H "Content-Type: application/json" \
  -d '{
    "username": "demo",
    "password": "demo123"
  }'

# Response
{
  "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
  "token_type": "bearer",
  "expires_in": 86400
}

2. Upload Document

# Upload a document
curl -X POST "http://localhost:8000/documents/upload" \
  -H "Authorization: Bearer YOUR_TOKEN" \
  -F "file=@document.pdf"

# Response
{
  "filename": "document.pdf",
  "document_id": "123e4567-e89b-12d3-a456-426614174000",
  "chunks": 45,
  "status": "indexed",
  "timestamp": "2024-01-15T10:30:00"
}

3. Query

# Query the system
curl -X POST "http://localhost:8000/query" \
  -H "Authorization: Bearer YOUR_TOKEN" \
  -H "Content-Type: application/json" \
  -d '{
    "question": "What is the main topic?",
    "top_k": 10,
    "rerank_top_k": 3,
    "include_sources": true
  }'

# Response
{
  "answer": "The main topic is...",
  "sources": [
    {
      "text": "...",
      "score": 0.95,
      "metadata": {...}
    }
  ],
  "metadata": {...},
  "query_id": "...",
  "timestamp": "2024-01-15T10:31:00"
}

4. Get Statistics

# Get system stats
curl -X GET "http://localhost:8000/stats" \
  -H "Authorization: Bearer YOUR_TOKEN"

# Response
{
  "total_documents": 150,
  "collection_name": "api_documents",
  "timestamp": "2024-01-15T10:32:00"
}

5. Health Check

# Check system health
curl -X GET "http://localhost:8000/health"

# Response
{
  "status": "healthy",
  "timestamp": "2024-01-15T10:33:00",
  "rag_initialized": true,
  "total_documents": 150
}

Python Client

Create a Python client for your API:
import requests
from typing import Optional, List

class MiniRAGClient:
    """Python client for Mini RAG API.

    Logs in on construction and sends the resulting bearer token with every
    authenticated request.

    Args:
        base_url: Root URL of the API; a trailing slash is removed.
        username: Account name used for login.
        password: Password used for login.
        timeout: Seconds to wait for each HTTP request. Without a timeout a
            stalled server would block the caller indefinitely.
    """

    def __init__(self, base_url: str, username: str, password: str, timeout: float = 30.0):
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.token = None
        self.login(username, password)

    def _request(self, method: str, path: str, *, auth: bool = True, **kwargs):
        """Send one request, raise on an HTTP error status, return parsed JSON."""
        response = requests.request(
            method,
            f"{self.base_url}{path}",
            headers=self._headers() if auth else None,
            timeout=self.timeout,
            **kwargs
        )
        response.raise_for_status()
        return response.json()

    def login(self, username: str, password: str):
        """Authenticate and get access token."""
        body = self._request(
            "POST",
            "/auth/login",
            auth=False,
            json={"username": username, "password": password}
        )
        self.token = body["access_token"]

    def _headers(self):
        """Get authentication headers."""
        return {"Authorization": f"Bearer {self.token}"}

    def query(
        self,
        question: str,
        top_k: int = 10,
        rerank_top_k: int = 3,
        include_sources: bool = True
    ):
        """Query the RAG system and return the decoded JSON response."""
        return self._request(
            "POST",
            "/query",
            json={
                "question": question,
                "top_k": top_k,
                "rerank_top_k": rerank_top_k,
                "include_sources": include_sources
            }
        )

    def upload_document(self, file_path: str):
        """Upload and index a document."""
        # Keep the file open for the whole request; it is sent as the
        # multipart field "file".
        with open(file_path, 'rb') as f:
            return self._request("POST", "/documents/upload", files={'file': f})

    def delete_document(self, document_id: str):
        """Delete a document."""
        return self._request("DELETE", f"/documents/{document_id}")

    def get_stats(self):
        """Get system statistics."""
        return self._request("GET", "/stats")

    def health_check(self):
        """Check system health (no authentication header is sent)."""
        return self._request("GET", "/health", auth=False)

# Usage: connect to a locally running API with the demo credentials
client = MiniRAGClient("http://localhost:8000", username="demo", password="demo123")

# Index a document and report how many chunks it produced
result = client.upload_document("document.pdf")
print(f"Uploaded: {result['filename']}, {result['chunks']} chunks")

# Ask a question about the indexed content
response = client.query("What is this about?")
print(f"Answer: {response['answer']}")

# Show how many documents the collection holds
stats = client.get_stats()
print(f"Total documents: {stats['total_documents']}")

Docker Deployment

Create a Dockerfile:
FROM python:3.11-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application
COPY . .

# Create uploads directory
RUN mkdir -p uploads

# Create non-root user
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
USER appuser

# Health check
# The slim base image does not ship curl, so probe the endpoint with the
# Python interpreter that is already present. urlopen raises on connection
# failure or an HTTP error status, which makes the command exit non-zero.
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health', timeout=5)" || exit 1

# Run application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
Build and run:
# Build
docker build -t minirag-api .

# Run
# JWT_SECRET must be supplied; without it the app falls back to the
# placeholder secret baked into the code.
docker run -p 8000:8000 \
  -e OPENAI_API_KEY=sk-... \
  -e MILVUS_URI=https://... \
  -e MILVUS_TOKEN=... \
  -e JWT_SECRET=... \
  minirag-api

Next Steps