Overview
This example demonstrates how to build a complete, production-ready REST API for your RAG system using FastAPI. It includes authentication, file uploads, health checks, and comprehensive error handling.
Features
RESTful API
Complete REST API with async support
File Upload
Upload and index documents via API
Authentication
JWT-based authentication
Monitoring
Health checks and metrics
Complete Implementation
import hmac
import os
import shutil
import uuid
import warnings
from datetime import datetime, timedelta
from pathlib import Path
from typing import List, Optional

import jwt
from dotenv import load_dotenv
from fastapi import (
    Depends,
    FastAPI,
    File,
    HTTPException,
    Security,
    UploadFile,
    status,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, Field, validator

from mini import (
    AgenticRAG,
    EmbeddingModel,
    LLMConfig,
    ObservabilityConfig,
    RetrievalConfig,
    VectorStore,
)
load_dotenv()
# Configuration
# Placeholder secret shipped with the example. It is publicly known, so any
# token signed with it can be forged; it must be overridden via JWT_SECRET.
_DEFAULT_JWT_SECRET = "your-secret-key-change-in-production"
JWT_SECRET = os.getenv("JWT_SECRET", _DEFAULT_JWT_SECRET)
if not JWT_SECRET or JWT_SECRET == _DEFAULT_JWT_SECRET:
    # Keep the example runnable, but make the insecure fallback impossible
    # to miss instead of silently signing tokens with a known key.
    warnings.warn(
        "JWT_SECRET is unset or still the built-in placeholder; "
        "set it to a long random value before deploying.",
        RuntimeWarning,
    )
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_HOURS = 24  # lifetime of an issued access token, in hours

# Directory where uploaded files are stored; created at import time.
UPLOAD_DIR = Path("./uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
ALLOWED_EXTENSIONS = {".pdf", ".docx", ".txt", ".md"}
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
# FastAPI application object; title, description and version are passed
# straight to the FastAPI constructor.
app = FastAPI(
    title="Mini RAG API",
    description="Production-ready RAG API with Mini RAG",
    version="1.0.0",
)

# CORS middleware: every origin, method and header is accepted as written.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Bearer-token scheme consumed by verify_token via Security(security).
security = HTTPBearer()

# Global RAG instance; stays None until the startup hook assigns it.
rag: Optional[AgenticRAG] = None
# Request/Response Models
class QueryRequest(BaseModel):
    """Request model for querying.

    Fields:
        question: The user's question, 1-500 characters; see
            ``validate_question`` for the extra checks applied.
        top_k: Number of chunks to retrieve (1-50, default 10).
        rerank_top_k: Number of chunks kept after reranking (1-10, default 3).
        include_sources: Whether the response should list source chunks.
    """

    question: str = Field(..., min_length=1, max_length=500)
    top_k: int = Field(default=10, ge=1, le=50)
    rerank_top_k: int = Field(default=3, ge=1, le=10)
    include_sources: bool = Field(default=True)

    @validator('question')
    def validate_question(cls, v):
        """Validate and sanitize question.

        Rejects questions containing ``<``, ``>``, ``{``, ``}`` or a NUL
        byte, strips surrounding whitespace, and rejects questions that are
        empty once stripped.

        Raises:
            ValueError: On a forbidden character or a blank question.
        """
        # '\x00' is the real NUL character. The previous '\\x00' was the
        # four-character text backslash-x-0-0, so NUL bytes slipped through.
        dangerous_chars = ('<', '>', '{', '}', '\x00')
        for char in dangerous_chars:
            if char in v:
                raise ValueError(f"Invalid character: {char}")

        # min_length=1 is checked before stripping, so "   " would otherwise
        # be accepted and turned into an empty question.
        cleaned = v.strip()
        if not cleaned:
            raise ValueError("Question must not be blank")
        return cleaned
class QueryResponse(BaseModel):
    """Response model for queries.

    Returned by the ``/query`` endpoint.
    """

    # Answer text taken from the RAG response.
    answer: str
    # One dict per source chunk; empty when sources were not requested.
    sources: List[dict]
    # Metadata dict passed through from the RAG response.
    metadata: dict
    # Identifier generated for this request.
    query_id: str
    # Time the response object was built.
    timestamp: datetime
class DocumentUploadResponse(BaseModel):
    """Response model for document uploads.

    Returned by the ``/documents/upload`` endpoint.
    """

    # Name of the uploaded file.
    filename: str
    # Identifier assigned to the document at upload time.
    document_id: str
    # Value returned by the indexing call for this document.
    chunks: int
    # Processing state reported to the client.
    status: str
    # Time the response object was built.
    timestamp: datetime
class HealthResponse(BaseModel):
    """Response model for health checks.

    Returned by the ``/health`` endpoint.
    """

    # Health label reported to the client.
    status: str
    # Time the health check ran.
    timestamp: datetime
    # Flag describing the state of the global RAG instance.
    rag_initialized: bool
    # Document count when available; defaults to None.
    total_documents: Optional[int] = None
class StatsResponse(BaseModel):
    """Response model for statistics.

    Returned by the ``/stats`` endpoint.
    """

    # 'total_documents' entry of the RAG stats.
    total_documents: int
    # 'collection_name' entry of the RAG stats.
    collection_name: str
    # Time the statistics were read.
    timestamp: datetime
class LoginRequest(BaseModel):
    """Request model for login.

    Body expected by the ``/auth/login`` endpoint.
    """

    # Account name supplied by the caller.
    username: str
    # Password supplied by the caller.
    password: str
class TokenResponse(BaseModel):
    """Response model for authentication.

    Returned by the ``/auth/login`` endpoint on success.
    """

    # Signed JWT to send as a bearer token.
    access_token: str
    # Token scheme label.
    token_type: str
    # Token lifetime in seconds.
    expires_in: int
# Authentication
def create_access_token(user_id: str) -> str:
    """Create JWT access token.

    The token carries ``user_id`` and an ``exp`` claim set
    ``JWT_EXPIRATION_HOURS`` hours after the current UTC time. It is signed
    with ``JWT_SECRET`` using ``JWT_ALGORITHM``.

    Args:
        user_id: Identifier stored in the token's ``user_id`` claim.

    Returns:
        The encoded token as produced by ``jwt.encode``.
    """
    expires_at = datetime.utcnow() + timedelta(hours=JWT_EXPIRATION_HOURS)
    claims = {"user_id": user_id, "exp": expires_at}
    return jwt.encode(claims, JWT_SECRET, algorithm=JWT_ALGORITHM)
def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error that carries the Bearer authentication challenge."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        # A 401 response should tell the client which scheme to use.
        headers={"WWW-Authenticate": "Bearer"},
    )


def verify_token(
    credentials: HTTPAuthorizationCredentials = Security(security)
) -> str:
    """Verify JWT token and return user ID.

    Args:
        credentials: Bearer credentials extracted by the ``security`` scheme.

    Returns:
        The ``user_id`` claim stored in the token.

    Raises:
        HTTPException: 401 when the token is expired, cannot be decoded, or
            has no ``user_id`` claim.
    """
    try:
        payload = jwt.decode(
            credentials.credentials,
            JWT_SECRET,
            algorithms=[JWT_ALGORITHM],
        )
    # ExpiredSignatureError must be handled first: it is reported with its
    # own message before the generic invalid-token case.
    except jwt.ExpiredSignatureError as exc:
        raise _unauthorized("Token has expired") from exc
    except jwt.InvalidTokenError as exc:
        raise _unauthorized("Invalid token") from exc

    user_id = payload.get("user_id")
    if user_id is None:
        raise _unauthorized("Invalid token")
    return user_id
# Startup/Shutdown Events
@app.on_event("startup")
async def startup_event():
    """Initialize RAG system on startup.

    Builds the embedding model, the vector store (connection settings read
    from ``MILVUS_URI``, ``MILVUS_TOKEN`` and ``COLLECTION_NAME``) and the
    ``AgenticRAG`` pipeline, stores the pipeline in the global ``rag`` and
    prints the current document count. Any failure is printed and re-raised.
    """
    global rag
    try:
        print("🚀 Initializing Mini RAG...")

        # Embedding model first, then the vector store it feeds.
        embedder = EmbeddingModel()
        store = VectorStore(
            uri=os.getenv("MILVUS_URI"),
            token=os.getenv("MILVUS_TOKEN"),
            collection_name=os.getenv("COLLECTION_NAME", "api_documents"),
            dimension=1536,
        )

        # Pipeline settings, built in the order they are handed to AgenticRAG.
        llm_settings = LLMConfig(
            model=os.getenv("LLM_MODEL", "gpt-4o-mini"),
            temperature=0.7,
        )
        retrieval_settings = RetrievalConfig(
            top_k=10,
            rerank_top_k=3,
            use_query_rewriting=True,
            use_reranking=True,
            use_hybrid_search=True,
        )
        observability_on = (
            os.getenv("ENABLE_OBSERVABILITY", "false").lower() == "true"
        )
        observability_settings = ObservabilityConfig(enabled=observability_on)

        rag = AgenticRAG(
            vector_store=store,
            embedding_model=embedder,
            llm_config=llm_settings,
            retrieval_config=retrieval_settings,
            observability_config=observability_settings,
        )

        stats = rag.get_stats()
        print("✅ Mini RAG initialized successfully")
        print(f"📊 Total documents: {stats['total_documents']}")
    except Exception as err:
        print(f"❌ Failed to initialize Mini RAG: {err}")
        raise
@app.on_event("shutdown")
async def shutdown_event():
    """Clean up resources on shutdown.

    Disconnects the vector store of the global ``rag`` instance when one
    exists. Errors are printed rather than raised.
    """
    global rag
    if not rag:
        return
    try:
        rag.vector_store.disconnect()
        print("✅ Mini RAG shut down successfully")
    except Exception as err:
        print(f"⚠️ Error during shutdown: {err}")
# API Endpoints
@app.post("/auth/login", response_model=TokenResponse)
async def login(request: LoginRequest):
    """
    Authenticate and receive an access token.
    In production, validate credentials against a database.

    Args:
        request: Username and password supplied by the caller.

    Returns:
        A ``TokenResponse`` with a bearer token valid for
        ``JWT_EXPIRATION_HOURS`` hours.

    Raises:
        HTTPException: 401 when the credentials do not match.
    """
    # TODO: Implement proper user authentication
    # This is a simplified example
    # Compare as bytes in constant time, and evaluate both checks before
    # combining them, so response timing does not reveal which field was
    # wrong. "surrogatepass" keeps odd JSON input from raising on encode.
    username_ok = hmac.compare_digest(
        request.username.encode("utf-8", "surrogatepass"), b"demo"
    )
    password_ok = hmac.compare_digest(
        request.password.encode("utf-8", "surrogatepass"), b"demo123"
    )
    if not (username_ok and password_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials"
        )

    token = create_access_token(request.username)
    return TokenResponse(
        access_token=token,
        token_type="bearer",
        expires_in=JWT_EXPIRATION_HOURS * 3600
    )
# Maximum number of characters of chunk text returned per source.
_SOURCE_PREVIEW_CHARS = 200


def _format_source(chunk) -> dict:
    """Convert one retrieved chunk into a source entry for the response."""
    text = chunk.text
    # Only mark the preview as truncated when text was actually cut.
    if len(text) > _SOURCE_PREVIEW_CHARS:
        text = text[:_SOURCE_PREVIEW_CHARS] + "..."
    # Test against None explicitly: a reranked score of 0.0 is a real score
    # and must not fall back to the retrieval score.
    score = chunk.reranked_score
    if score is None:
        score = chunk.score
    return {"text": text, "score": score, "metadata": chunk.metadata}


@app.post("/query", response_model=QueryResponse)
async def query(
    request: QueryRequest,
    user_id: str = Depends(verify_token)
):
    """
    Query the RAG system with a question.
    Requires authentication.

    Args:
        request: Validated question and retrieval options.
        user_id: User ID returned by ``verify_token``.

    Returns:
        A ``QueryResponse`` with the answer, optional sources, metadata, a
        freshly generated query ID and a UTC timestamp.

    Raises:
        HTTPException: 503 when the RAG system is not initialized, 500 when
            the query fails.
    """
    if not rag:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized"
        )
    try:
        # Query the RAG system
        response = rag.query(
            query=request.question,
            top_k=request.top_k,
            rerank_top_k=request.rerank_top_k,
            return_sources=request.include_sources
        )

        # Format sources
        sources = []
        if request.include_sources:
            sources = [
                _format_source(chunk) for chunk in response.retrieved_chunks
            ]

        return QueryResponse(
            answer=response.answer,
            sources=sources,
            metadata=response.metadata,
            query_id=str(uuid.uuid4()),
            timestamp=datetime.utcnow()
        )
    except HTTPException:
        # Never rewrap a deliberate HTTP error as a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Query failed: {str(e)}"
        ) from e
@app.post("/documents/upload", response_model=DocumentUploadResponse)
async def upload_document(
    file: UploadFile = File(...),
    user_id: str = Depends(verify_token)
):
    """
    Upload and index a document.
    Requires authentication.

    Args:
        file: The uploaded file; its extension must be in
            ``ALLOWED_EXTENSIONS`` and its size at most ``MAX_FILE_SIZE``.
        user_id: User ID returned by ``verify_token``; stored in metadata.

    Returns:
        A ``DocumentUploadResponse`` describing the indexed document.

    Raises:
        HTTPException: 503 when the RAG system is not initialized, 400 for a
            disallowed file type or an oversized file, 500 when saving or
            indexing fails.
    """
    if not rag:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized"
        )

    # The filename is client-controlled: keep only its last path component
    # (treating backslashes as separators too) so it cannot point outside
    # UPLOAD_DIR. A missing filename becomes "" and fails the check below.
    original_name = Path((file.filename or "").replace("\\", "/")).name

    # Validate file extension
    file_ext = Path(original_name).suffix.lower()
    if file_ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"File type not allowed. Allowed: {ALLOWED_EXTENSIONS}"
        )

    # Generate unique filename
    document_id = str(uuid.uuid4())
    file_path = UPLOAD_DIR / f"{document_id}_{original_name}"
    chunk_size = 1024 * 1024  # read the upload 1 MiB at a time

    def _discard_partial_file():
        """Remove the saved file, if any, after a failed upload."""
        if file_path.exists():
            file_path.unlink()

    try:
        # Save uploaded file, enforcing the size limit while copying so an
        # oversized upload is rejected before it is fully written to disk.
        file_size = 0
        with file_path.open("wb") as buffer:
            for chunk in iter(lambda: file.file.read(chunk_size), b""):
                file_size += len(chunk)
                if file_size > MAX_FILE_SIZE:
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail=f"File too large. Max size: {MAX_FILE_SIZE / 1024 / 1024}MB"
                    )
                buffer.write(chunk)

        # Index the document
        num_chunks = rag.index_document(
            str(file_path),
            metadata={
                "document_id": document_id,
                "filename": original_name,
                "uploaded_by": user_id,
                "uploaded_at": datetime.utcnow().isoformat(),
                "file_size": file_size
            }
        )
        return DocumentUploadResponse(
            filename=original_name,
            document_id=document_id,
            chunks=num_chunks,
            status="indexed",
            timestamp=datetime.utcnow()
        )
    except HTTPException:
        # Deliberate client errors (e.g. file too large) keep their status
        # code instead of being rewrapped as a 500 below.
        _discard_partial_file()
        raise
    except Exception as e:
        # Clean up on error
        _discard_partial_file()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Upload failed: {str(e)}"
        ) from e
@app.delete("/documents/{document_id}")
async def delete_document(
    document_id: str,
    user_id: str = Depends(verify_token)
):
    """
    Delete a document.
    Requires authentication.

    Args:
        document_id: ID returned by the upload endpoint (a UUID).
        user_id: User ID returned by ``verify_token``.

    Returns:
        A dict with the document ID, the status ``"deleted"`` and the number
        reported by the vector store's delete call.

    Raises:
        HTTPException: 503 when the RAG system is not initialized, 404 when
            the ID is not a valid UUID or matches nothing, 500 when deletion
            fails.
    """
    if not rag:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized"
        )

    # document_id is interpolated into a filter expression and a glob
    # pattern below, so it must be validated first. Upload IDs are always
    # str(uuid.uuid4()); round-tripping through uuid.UUID leaves only hex
    # digits and hyphens, ruling out quotes, operators and wildcards.
    try:
        document_id = str(uuid.UUID(document_id))
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Document not found"
        ) from None

    try:
        # Delete from vector store
        deleted = rag.vector_store.delete(f'metadata["document_id"] == "{document_id}"')
        if deleted == 0:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Document not found"
            )

        # Delete file if exists
        for file_path in UPLOAD_DIR.glob(f"{document_id}_*"):
            file_path.unlink()

        return {
            "document_id": document_id,
            "status": "deleted",
            "chunks_deleted": deleted
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Deletion failed: {str(e)}"
        ) from e
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Reports ``"healthy"`` when the global RAG instance exists and its stats
    can be read, ``"unhealthy"`` otherwise. ``rag_initialized`` reflects only
    whether the instance exists, so a failing stats call is distinguishable
    from a system that never started.
    """
    rag_initialized = rag is not None
    is_healthy = rag_initialized
    total_docs = None
    if rag_initialized:
        try:
            total_docs = rag.get_stats()['total_documents']
        except Exception as e:
            # A bare ``except:`` would also trap task cancellation and
            # interpreter exit; catch ordinary errors only and report them.
            print(f"⚠️ Health check failed: {e}")
            is_healthy = False
    return HealthResponse(
        status="healthy" if is_healthy else "unhealthy",
        timestamp=datetime.utcnow(),
        rag_initialized=rag_initialized,
        total_documents=total_docs
    )
@app.get("/stats", response_model=StatsResponse)
async def get_stats(user_id: str = Depends(verify_token)):
    """
    Get system statistics.
    Requires authentication.

    Responds with 503 when the RAG system is not initialized and with 500
    when the statistics cannot be read or converted into a response.
    """
    if not rag:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="RAG system not initialized",
        )

    try:
        snapshot = rag.get_stats()
        document_count = snapshot['total_documents']
        collection = snapshot['collection_name']
        return StatsResponse(
            total_documents=document_count,
            collection_name=collection,
            timestamp=datetime.utcnow(),
        )
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get stats: {err}",
        )
@app.get("/")
async def root():
    """Root endpoint.

    Returns a small static description of the service: its name, version,
    running status and the path of the interactive docs.
    """
    service_info = {
        "name": "Mini RAG API",
        "version": "1.0.0",
        "status": "running",
        "docs": "/docs",
    }
    return service_info
# Error handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Custom HTTP exception handler.

    Renders an ``HTTPException`` as a JSON body containing the error detail,
    the status code and a UTC timestamp.

    Args:
        request: The incoming request (unused).
        exc: The ``HTTPException`` being handled.

    Returns:
        A ``JSONResponse`` carrying the exception's status code and headers.
    """
    # An exception handler must return a Response object. Returning a plain
    # dict fails at runtime and would also drop the status code and any
    # headers (such as WWW-Authenticate) attached to the exception.
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "status_code": exc.status_code,
            "timestamp": datetime.utcnow().isoformat()
        },
        headers=getattr(exc, "headers", None),
    )
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    # PORT comes from the environment and falls back to 8000.
    listen_port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")
Running the API
Set Environment Variables
Create a `.env` file:
OPENAI_API_KEY=sk-...
MILVUS_URI=https://...
MILVUS_TOKEN=...
JWT_SECRET=your-secret-key-change-in-production
COLLECTION_NAME=api_documents
LLM_MODEL=gpt-4o-mini
ENABLE_OBSERVABILITY=false
PORT=8000
Run the Server
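# Option 1: run the script directly
python main.py
# Option 2: run with uvicorn (--reload restarts the server on code changes; use it for development only)
uvicorn main:app --host 0.0.0.0 --port 8000 --reload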
API Usage Examples
1. Authentication
# Login to get access token
curl -X POST "http://localhost:8000/auth/login" \
-H "Content-Type: application/json" \
-d '{
"username": "demo",
"password": "demo123"
}'
# Response
{
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
"token_type": "bearer",
"expires_in": 86400
}
2. Upload Document
# Upload a document
curl -X POST "http://localhost:8000/documents/upload" \
-H "Authorization: Bearer YOUR_TOKEN" \
-F "file=@document.pdf"
# Response
{
"filename": "document.pdf",
"document_id": "123e4567-e89b-12d3-a456-426614174000",
"chunks": 45,
"status": "indexed",
"timestamp": "2024-01-15T10:30:00"
}
3. Query
# Query the system
curl -X POST "http://localhost:8000/query" \
-H "Authorization: Bearer YOUR_TOKEN" \
-H "Content-Type: application/json" \
-d '{
"question": "What is the main topic?",
"top_k": 10,
"rerank_top_k": 3,
"include_sources": true
}'
# Response
{
"answer": "The main topic is...",
"sources": [
{
"text": "...",
"score": 0.95,
"metadata": {...}
}
],
"metadata": {...},
"query_id": "...",
"timestamp": "2024-01-15T10:31:00"
}
4. Get Statistics
# Get system stats
curl -X GET "http://localhost:8000/stats" \
-H "Authorization: Bearer YOUR_TOKEN"
# Response
{
"total_documents": 150,
"collection_name": "api_documents",
"timestamp": "2024-01-15T10:32:00"
}
5. Health Check
# Check system health
curl -X GET "http://localhost:8000/health"
# Response
{
"status": "healthy",
"timestamp": "2024-01-15T10:33:00",
"rag_initialized": true,
"total_documents": 150
}
Python Client
Create a Python client for your API:
import requests
from typing import Optional, List
class MiniRAGClient:
    """Python client for Mini RAG API.

    Logs in on construction and sends the resulting bearer token with every
    authenticated call.

    Args:
        base_url: Root URL of the API; a trailing slash is removed.
        username: Account name sent to ``/auth/login``.
        password: Password sent to ``/auth/login``.
        timeout: Seconds to wait for the server on each request. Without a
            timeout ``requests`` can block forever on an unresponsive
            server. Pass ``None`` to wait indefinitely, or a larger value if
            indexing big documents takes longer than the default.
    """

    def __init__(
        self,
        base_url: str,
        username: str,
        password: str,
        timeout: Optional[float] = 60.0,
    ):
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.token = None
        self.login(username, password)

    def login(self, username: str, password: str):
        """Authenticate and get access token.

        Stores the token on the client. Raises ``requests.HTTPError`` when
        the server rejects the credentials.
        """
        response = requests.post(
            f"{self.base_url}/auth/login",
            json={"username": username, "password": password},
            timeout=self.timeout,
        )
        response.raise_for_status()
        self.token = response.json()["access_token"]

    def _headers(self):
        """Get authentication headers."""
        return {"Authorization": f"Bearer {self.token}"}

    def query(
        self,
        question: str,
        top_k: int = 10,
        rerank_top_k: int = 3,
        include_sources: bool = True
    ):
        """Query the RAG system.

        Returns the decoded JSON body of the ``/query`` response.
        """
        response = requests.post(
            f"{self.base_url}/query",
            json={
                "question": question,
                "top_k": top_k,
                "rerank_top_k": rerank_top_k,
                "include_sources": include_sources
            },
            headers=self._headers(),
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def upload_document(self, file_path: str):
        """Upload and index a document.

        Returns the decoded JSON body of the upload response.
        """
        # The file must stay open until the request has been sent.
        with open(file_path, 'rb') as f:
            files = {'file': f}
            response = requests.post(
                f"{self.base_url}/documents/upload",
                files=files,
                headers=self._headers(),
                timeout=self.timeout,
            )
        response.raise_for_status()
        return response.json()

    def delete_document(self, document_id: str):
        """Delete a document.

        Returns the decoded JSON body of the delete response.
        """
        response = requests.delete(
            f"{self.base_url}/documents/{document_id}",
            headers=self._headers(),
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def get_stats(self):
        """Get system statistics."""
        response = requests.get(
            f"{self.base_url}/stats",
            headers=self._headers(),
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def health_check(self):
        """Check system health (no authentication header is sent)."""
        response = requests.get(
            f"{self.base_url}/health",
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()
# Usage: constructing the client logs in immediately with these credentials.
client = MiniRAGClient(
    base_url="http://localhost:8000",
    username="demo",
    password="demo123",
)

# Upload document and report what the server returned for it.
result = client.upload_document("document.pdf")
print(f"Uploaded: {result['filename']}, {result['chunks']} chunks")

# Query the indexed content and show the answer.
response = client.query("What is this about?")
print(f"Answer: {response['answer']}")

# Get stats and show the document count.
stats = client.get_stats()
print(f"Total documents: {stats['total_documents']}")
Docker Deployment
Create a Dockerfile:
FROM python:3.11-slim
WORKDIR /app
# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application
COPY . .
# Create uploads directory
RUN mkdir -p uploads
# Create non-root user
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
USER appuser
# Health check
# The slim base image does not ship curl, so probe with the Python
# interpreter that is already present; urlopen raises (non-zero exit) on
# connection failure or an HTTP error status.
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health', timeout=5)" || exit 1
# Run application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
# Build
docker build -t minirag-api .
# Run
docker run -p 8000:8000 \
-e OPENAI_API_KEY=sk-... \
-e MILVUS_URI=https://... \
-e MILVUS_TOKEN=... \
minirag-api
Next Steps
Production Guide
Deploy to production
Observability
Monitor your API
Configuration
Optimize settings
Security
Secure your API
