Extending Navarch#
Navarch is designed to be extended. You can add custom cloud providers, autoscaling strategies, workload notifiers, and metrics sources without modifying the core codebase.
Custom providers#
Providers handle instance lifecycle: provisioning, listing, and terminating GPU instances.
Provider interface#
type Provider interface {
// Name returns the provider identifier (e.g., "lambda", "gcp", "aws").
Name() string
// Provision creates a new instance with the given specification.
Provision(ctx context.Context, req ProvisionRequest) (*Instance, error)
// Terminate destroys an instance.
Terminate(ctx context.Context, instanceID string) error
// List returns all instances managed by this provider.
List(ctx context.Context) ([]*Instance, error)
}
ProvisionRequest#
type ProvisionRequest struct {
InstanceType string // Provider-specific instance type
Region string // Region to provision in
Zone string // Optional: specific zone
SSHKeyNames []string // SSH keys to inject
Labels map[string]string // Labels to apply
UserData string // Cloud-init or startup script
}
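A filled-in request might look like the following sketch; the instance type, key name, and labels are illustrative values, not Navarch defaults:
req := provider.ProvisionRequest{
	InstanceType: "gpu_1x_h100",                          // illustrative, provider-specific
	Region:       "us-east-1",
	SSHKeyNames:  []string{"ops-key"},                     // assumes the key is already registered with the provider
	Labels:       map[string]string{"pool": "training"},
	UserData:     "#cloud-config\nruncmd:\n  - echo ready",
}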
Example: Custom provider#
package myprovider
import (
	"context"
	"fmt"

	"github.com/NavarchProject/navarch/pkg/provider"
)
type MyCloudProvider struct {
client *MyCloudClient
region string
}
func New(apiKey, region string) (*MyCloudProvider, error) {
client, err := NewMyCloudClient(apiKey)
if err != nil {
return nil, err
}
return &MyCloudProvider{client: client, region: region}, nil
}
func (p *MyCloudProvider) Name() string {
return "mycloud"
}
func (p *MyCloudProvider) Provision(ctx context.Context, req provider.ProvisionRequest) (*provider.Instance, error) {
resp, err := p.client.CreateInstance(ctx, CreateRequest{
Type: req.InstanceType,
Region: p.region,
SSHKeys: req.SSHKeyNames,
UserData: req.UserData,
})
if err != nil {
return nil, fmt.Errorf("create instance: %w", err)
}
return &provider.Instance{
ID: resp.ID,
ProviderName: p.Name(),
Status: provider.StatusPending,
Region: p.region,
InstanceType: req.InstanceType,
}, nil
}
func (p *MyCloudProvider) Terminate(ctx context.Context, instanceID string) error {
return p.client.DeleteInstance(ctx, instanceID)
}
func (p *MyCloudProvider) List(ctx context.Context) ([]*provider.Instance, error) {
instances, err := p.client.ListInstances(ctx)
if err != nil {
return nil, err
}
result := make([]*provider.Instance, len(instances))
for i, inst := range instances {
result[i] = &provider.Instance{
ID: inst.ID,
ProviderName: p.Name(),
Status: mapStatus(inst.State),
PublicIP: inst.PublicIP,
PrivateIP: inst.PrivateIP,
Region: inst.Region,
}
}
return result, nil
}
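The List implementation above calls a mapStatus helper to translate provider-specific states into Navarch statuses. A minimal sketch follows; the provider-side state strings, the provider.Status return type, and every constant other than StatusPending are assumptions, so adjust them to match pkg/provider:
func mapStatus(state string) provider.Status {
	switch state {
	case "booting", "provisioning":
		return provider.StatusPending
	case "active", "running":
		return provider.StatusActive // assumed constant name
	case "terminating", "terminated":
		return provider.StatusTerminated // assumed constant name
	default:
		return provider.StatusUnknown // assumed constant name
	}
}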
See pkg/provider/lambda/ for a production implementation. For testing, see the Docker provider.
Registering your provider#
myCloud, err := myprovider.New(os.Getenv("MYCLOUD_API_KEY"), "us-east-1")
if err != nil {
log.Fatal(err)
}
controlPlane.RegisterProvider(myCloud)
Custom autoscalers#
Autoscalers decide when to scale pools up or down based on metrics.
Autoscaler interface#
type Autoscaler interface {
// Name returns the autoscaler type identifier.
Name() string
// Recommend returns the target node count for the pool.
Recommend(ctx context.Context, state PoolState) (ScaleRecommendation, error)
}
type PoolState struct {
CurrentNodes int
MinNodes int
MaxNodes int
Metrics PoolMetrics
}
type PoolMetrics struct {
Utilization float64 // Average GPU utilization (0-100)
PendingJobs int // Jobs waiting to be scheduled
QueueDepth int // Total jobs in queue
UtilizationHistory []float64 // Historical samples for trend analysis
}
type ScaleRecommendation struct {
TargetNodes int
Reason string
}
Example: Custom autoscaler#
A cost-aware autoscaler that scales down aggressively outside business hours:
package costscaler
import (
	"context"
	"fmt"
	"time"

	"github.com/NavarchProject/navarch/pkg/pool"
)
type CostAwareScaler struct {
peakStart int // Hour (0-23)
peakEnd int
peakThreshold float64 // Scale up when utilization exceeds this
offPeakMin int // Minimum nodes outside peak hours
}
func New(peakStart, peakEnd int, threshold float64, offPeakMin int) *CostAwareScaler {
return &CostAwareScaler{
peakStart: peakStart,
peakEnd: peakEnd,
peakThreshold: threshold,
offPeakMin: offPeakMin,
}
}
func (s *CostAwareScaler) Name() string {
return "cost-aware"
}
func (s *CostAwareScaler) Recommend(ctx context.Context, state pool.PoolState) (pool.ScaleRecommendation, error) {
hour := time.Now().Hour()
isPeak := hour >= s.peakStart && hour < s.peakEnd
current := state.CurrentNodes
util := state.Metrics.Utilization
// Outside peak hours, scale to minimum
if !isPeak {
target := s.offPeakMin
if target < state.MinNodes {
target = state.MinNodes
}
return pool.ScaleRecommendation{
TargetNodes: target,
Reason: "off-peak hours, scaling to minimum",
}, nil
}
// During peak, scale based on utilization
if util > s.peakThreshold && current < state.MaxNodes {
return pool.ScaleRecommendation{
TargetNodes: current + 1,
Reason: fmt.Sprintf("utilization %.1f%% exceeds threshold", util),
}, nil
}
if util < s.peakThreshold/2 && current > state.MinNodes {
return pool.ScaleRecommendation{
TargetNodes: current - 1,
Reason: fmt.Sprintf("utilization %.1f%% below threshold", util),
}, nil
}
return pool.ScaleRecommendation{
TargetNodes: current,
Reason: "no change needed",
}, nil
}
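Constructing the scaler is a one-liner; the arguments below mirror the unit test later on this page:
// Peak window 09:00-18:00, scale up above 80% utilization, keep 2 nodes off-peak.
scaler := costscaler.New(9, 18, 80.0, 2)
How the scaler is attached to a pool depends on your control-plane wiring; the built-in implementations referenced below satisfy the same interface.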
See pkg/pool/autoscaler.go for built-in implementations (reactive, queue, scheduled, predictive, composite).
Custom metrics sources#
For queue-based or custom autoscaling, provide metrics from your scheduler.
MetricsSource interface#
type MetricsSource interface {
GetPoolMetrics(ctx context.Context, poolName string) (*PoolMetrics, error)
}
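Any type with this single method satisfies the interface. As a minimal sketch using only the types shown above, a static source can stand in during tests:
type StaticMetricsSource struct {
	Metrics PoolMetrics
}

func (s *StaticMetricsSource) GetPoolMetrics(ctx context.Context, poolName string) (*PoolMetrics, error) {
	m := s.Metrics // return a copy so callers cannot mutate the source
	return &m, nil
}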
Example: Kubernetes integration#
type K8sMetricsSource struct {
client *kubernetes.Clientset
namespace string
}
func (m *K8sMetricsSource) GetPoolMetrics(ctx context.Context, poolName string) (*PoolMetrics, error) {
pods, err := m.client.CoreV1().Pods(m.namespace).List(ctx, metav1.ListOptions{
LabelSelector: "navarch.dev/pool=" + poolName,
})
if err != nil {
return nil, err
}
var pending, running int
for _, pod := range pods.Items {
switch pod.Status.Phase {
case corev1.PodPending:
pending++
case corev1.PodRunning:
running++
}
}
return &PoolMetrics{
PendingJobs: pending,
QueueDepth: pending + running,
}, nil
}
Example: Slurm integration#
type SlurmMetricsSource struct {
partition string
}
func (m *SlurmMetricsSource) GetPoolMetrics(ctx context.Context, poolName string) (*PoolMetrics, error) {
	// Query pending jobs
	out, err := exec.CommandContext(ctx, "squeue", "-p", m.partition, "-t", "PENDING", "-h", "-o", "%i").Output()
	if err != nil {
		return nil, err
	}
	pending := countJobs(string(out))
	// Query running jobs
	out, err = exec.CommandContext(ctx, "squeue", "-p", m.partition, "-t", "RUNNING", "-h", "-o", "%i").Output()
	if err != nil {
		return nil, err
	}
	running := countJobs(string(out))
	return &PoolMetrics{
		PendingJobs: pending,
		QueueDepth:  pending + running,
	}, nil
}

// countJobs counts non-empty lines in squeue output. An empty result means
// zero jobs, not one, so the blank case is handled explicitly.
func countJobs(out string) int {
	out = strings.TrimSpace(out)
	if out == "" {
		return 0
	}
	return len(strings.Split(out, "\n"))
}
Custom notifiers#
Notifiers integrate Navarch with your workload management system. When Navarch cordons, drains, or uncordons a node, the notifier tells your scheduler (Kubernetes, Slurm, Ray, or custom) to stop scheduling work and migrate existing workloads.
Notifier interface#
type Notifier interface {
// Cordon marks a node as unschedulable. The workload system should stop
// placing new workloads on this node. Existing workloads continue.
Cordon(ctx context.Context, nodeID string, reason string) error
// Uncordon marks a node as schedulable again.
Uncordon(ctx context.Context, nodeID string) error
// Drain requests the workload system to migrate workloads off the node.
Drain(ctx context.Context, nodeID string, reason string) error
// IsDrained returns true if all workloads have been migrated off
// the node and it's safe to terminate.
IsDrained(ctx context.Context, nodeID string) (bool, error)
// Name returns the notifier name for logging.
Name() string
}
Example: Kubernetes notifier#
package k8snotifier
import (
	"context"
	"fmt"

	corev1 "k8s.io/api/core/v1"
	policyv1 "k8s.io/api/policy/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes"
)
type K8sNotifier struct {
client *kubernetes.Clientset
}
func New(client *kubernetes.Clientset) *K8sNotifier {
return &K8sNotifier{client: client}
}
func (n *K8sNotifier) Name() string {
return "kubernetes"
}
func (n *K8sNotifier) Cordon(ctx context.Context, nodeID string, reason string) error {
node, err := n.client.CoreV1().Nodes().Get(ctx, nodeID, metav1.GetOptions{})
if err != nil {
return fmt.Errorf("get node: %w", err)
}
node.Spec.Unschedulable = true
_, err = n.client.CoreV1().Nodes().Update(ctx, node, metav1.UpdateOptions{})
return err
}
func (n *K8sNotifier) Uncordon(ctx context.Context, nodeID string) error {
node, err := n.client.CoreV1().Nodes().Get(ctx, nodeID, metav1.GetOptions{})
if err != nil {
return fmt.Errorf("get node: %w", err)
}
node.Spec.Unschedulable = false
_, err = n.client.CoreV1().Nodes().Update(ctx, node, metav1.UpdateOptions{})
return err
}
func (n *K8sNotifier) Drain(ctx context.Context, nodeID string, reason string) error {
// List pods on the node
pods, err := n.client.CoreV1().Pods("").List(ctx, metav1.ListOptions{
FieldSelector: "spec.nodeName=" + nodeID,
})
if err != nil {
return fmt.Errorf("list pods: %w", err)
}
// Evict each non-system pod; a failed eviction aborts the drain so it can be retried
for _, pod := range pods.Items {
	if pod.Namespace == "kube-system" {
		continue // Skip system pods
	}
	eviction := &policyv1.Eviction{
		ObjectMeta: metav1.ObjectMeta{
			Name:      pod.Name,
			Namespace: pod.Namespace,
		},
	}
	if err := n.client.CoreV1().Pods(pod.Namespace).EvictV1(ctx, eviction); err != nil {
		return fmt.Errorf("evict %s/%s: %w", pod.Namespace, pod.Name, err)
	}
}
return nil
}
func (n *K8sNotifier) IsDrained(ctx context.Context, nodeID string) (bool, error) {
pods, err := n.client.CoreV1().Pods("").List(ctx, metav1.ListOptions{
FieldSelector: "spec.nodeName=" + nodeID,
})
if err != nil {
return false, fmt.Errorf("list pods: %w", err)
}
// Check if only system pods remain
for _, pod := range pods.Items {
if pod.Namespace != "kube-system" && pod.Status.Phase == corev1.PodRunning {
return false, nil
}
}
return true, nil
}
Example: Slurm notifier#
package slurmnotifier
import (
"context"
"fmt"
"os/exec"
"strings"
)
type SlurmNotifier struct{}
func New() *SlurmNotifier {
return &SlurmNotifier{}
}
func (n *SlurmNotifier) Name() string {
return "slurm"
}
func (n *SlurmNotifier) Cordon(ctx context.Context, nodeID string, reason string) error {
// Set node to drain state (no new jobs)
cmd := exec.CommandContext(ctx, "scontrol", "update", "nodename="+nodeID, "state=drain", "reason="+reason)
return cmd.Run()
}
func (n *SlurmNotifier) Uncordon(ctx context.Context, nodeID string) error {
cmd := exec.CommandContext(ctx, "scontrol", "update", "nodename="+nodeID, "state=resume")
return cmd.Run()
}
func (n *SlurmNotifier) Drain(ctx context.Context, nodeID string, reason string) error {
// Slurm drain state already prevents new jobs; jobs finish naturally
return n.Cordon(ctx, nodeID, reason)
}
func (n *SlurmNotifier) IsDrained(ctx context.Context, nodeID string) (bool, error) {
// Check if any jobs are running on this node
cmd := exec.CommandContext(ctx, "squeue", "-w", nodeID, "-h", "-o", "%i")
out, err := cmd.Output()
if err != nil {
return false, fmt.Errorf("squeue: %w", err)
}
return strings.TrimSpace(string(out)) == "", nil
}
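Wiring a notifier into the control plane should mirror provider registration; the RegisterNotifier name below is an assumption modeled on RegisterProvider, so confirm the actual hook in the control-plane API:
// Hypothetical registration call, analogous to controlPlane.RegisterProvider above.
controlPlane.RegisterNotifier(slurmnotifier.New())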
Built-in notifiers#
Navarch includes two built-in notifiers:
| Notifier | Description |
|---|---|
| noop | Logs operations but takes no action. Default when no notifier is configured. |
| webhook | Sends HTTP requests to your workload system. See configuration. |
See pkg/notifier/ for implementation details.
Testing extensions#
Unit tests#
func TestCostAwareScaler_OffPeak(t *testing.T) {
scaler := costscaler.New(9, 18, 80.0, 2)
state := pool.PoolState{
CurrentNodes: 10,
MinNodes: 1,
MaxNodes: 20,
Metrics: pool.PoolMetrics{Utilization: 50.0},
}
	// Recommend reads the wall clock, so this assertion only holds when the
	// test runs outside the 09:00-18:00 peak window; for a deterministic test,
	// inject a clock into the scaler instead of calling time.Now directly.
	rec, err := scaler.Recommend(context.Background(), state)
	require.NoError(t, err)
	assert.Equal(t, 2, rec.TargetNodes) // Should scale to off-peak minimum
}
Integration tests with simulator#
Test your extensions with realistic scenarios:
name: test-custom-autoscaler
fleet:
- id: node-1
provider: fake
gpu_count: 8
- id: node-2
provider: fake
gpu_count: 8
events:
- at: 0s
action: start_fleet
- at: 30s
action: assert
target: node-1
params:
expected_status: active
Best practices#
- Handle context cancellation: Check ctx.Done() in long-running operations (a sketch follows this list).
- Return wrapped errors: Use fmt.Errorf("operation: %w", err) for debuggable errors.
- Log at appropriate levels: Debug for verbose details, Info for operations, Error for failures.
- Test with the simulator: Validate behavior before deploying.
- Document configuration options: Make your extension easy to configure.
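As a sketch of the first point, a polling loop that respects cancellation might look like this; the helper name and interval are placeholders:
func pollUntilDrained(ctx context.Context, n Notifier, nodeID string) error {
	ticker := time.NewTicker(15 * time.Second) // placeholder interval
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err() // stop promptly when the operation is cancelled
		case <-ticker.C:
			drained, err := n.IsDrained(ctx, nodeID)
			if err != nil {
				return fmt.Errorf("check drain status: %w", err)
			}
			if drained {
				return nil
			}
		}
	}
}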