Fine-Tuning Mistral with FSDP and LoRA on OpenShift

A production guide to fine-tuning Mistral models using FSDP2, LoRA/PEFT, and Run:ai on OpenShift with InfiniBand RDMA. Covers NCCL configuration, memory.

Luca Berton
· 6 min read

Fine-tuning large language models in enterprise environments requires more than just a training script. You need multi-node orchestration, InfiniBand networking, memory-efficient sharding, and production-grade job submission. Here’s a complete walkthrough of fine-tuning Mistral 4 Small using FSDP2 + LoRA (PEFT) across multiple GPU nodes on OpenShift with Run:ai.

Architecture Overview

┌─────────────────────────────────────────────────────┐
│  Run:ai Job Scheduler                               │
│  ┌──────────────┐  ┌──────────────┐  ┌────────────┐ │
│  │  Node 0      │  │  Node 1      │  │  Node 2    │ │
│  │  1× GPU      │  │  1× GPU      │  │  1× GPU    │ │
│  │  FSDP Shard  │  │  FSDP Shard  │  │  FSDP Shard│ │
│  └──────┬───────┘  └──────┬───────┘  └──────┬─────┘ │
│         │    InfiniBand / Mellanox RDMA     │       │
│         └─────────────────┬─────────────────┘       │
│                     NCCL AllReduce                  │
└─────────────────────────────────────────────────────┘

FSDP Configuration

Fully Sharded Data Parallel (FSDP) version 2 distributes model parameters, gradients, and optimizer states across all workers. This is the fsdp.yaml accelerate configuration:

compute_environment: LOCAL_MACHINE
debug: true
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
mixed_precision: bf16
num_machines: 3
num_processes: 3
rdzv_backend: static
same_network: true
use_cpu: false

fsdp_config:
  fsdp_version: 2

  # CRITICAL for PEFT compatibility with FSDP2
  fsdp_use_orig_params: true

  # --- Wrapping ---
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer

  # --- Sharding ---
  # FSDP2 counterpart of FSDP1's FULL_SHARD: re-shard parameters after
  # each layer's forward pass so gathered weights are freed until backward
  fsdp_reshard_after_forward: true
  fsdp_limit_all_gathers: true

  # --- Memory / Loading ---
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true

  # --- Prefetch ---
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: true

  # --- State dict ---
  fsdp_state_dict_type: FULL_STATE_DICT
  state_dict_cpu_offload: true

  # Gradient checkpointing managed by SFTConfig
  fsdp_activation_checkpointing: false

Key Configuration Decisions

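fsdp_use_orig_params: true - This is critical when combining FSDP with PEFT/LoRA. Without it, FSDP1 flattens each wrapped module's parameters into a single flat parameter that must share one requires_grad setting, so frozen base weights and trainable LoRA adapters cannot be mixed. FSDP2 shards each parameter individually and keeps the original parameters, so the flag mainly matters if the same file is reused with fsdp_version: 1.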

fsdp_reshard_after_forward: true - Equivalent to FULL_SHARD in FSDP1. After the forward pass on each layer, parameters are re-sharded across workers. This minimizes peak memory at the cost of extra communication.

TRANSFORMER_BASED_WRAP - Wraps each MistralDecoderLayer as an FSDP unit. This gives optimal granularity: fine enough for memory distribution, coarse enough to minimize communication overhead.

fsdp_offload_params: true - Offloads the sharded parameters, gradients, and optimizer states to CPU RAM, so the optimizer step runs on CPU. Essential when GPU memory is tight, though it increases training time by 20-30%.

LoRA Training Configuration

The training configuration uses parameter-efficient fine-tuning to train only a small subset of the model:

# --- LoRA / PEFT ---
lora:
  lora_dropout: 0.05
  target_modules:
    - q_proj
    - k_proj
    - v_proj
    - o_proj
    - gate_proj
    - up_proj
    - down_proj
  bias: none
  task_type: CAUSAL_LM

# --- Dataset ---
data:
  dataset_name_or_path: /data/input/Datasets/Salesforce/wikitext
  dataset_config_name: wikitext-2-raw-v1
  text_field: text
  train_split: train
  val_split: validation
  max_length: 1024
  prompt_template: null

# --- Training ---
train:
  output_dir: /data/output/Models/mistral4small-lora-fsdp-2xH200
  num_train_epochs: 1

  # FSDP: batch per GPU × 3 GPUs = effective batch before accumulation
  per_device_train_batch_size: 1
  per_device_eval_batch_size: 1
  gradient_accumulation_steps: 16  # effective batch = 1 × 3 GPUs × 16 = 48

  gradient_checkpointing: false

  # adamw_torch recommended with FSDP (not paged, no CPU offload needed)
  optim: adamw_torch

  learning_rate: 0.0001
  lr_scheduler_type: cosine
  warmup_ratio: 0.05
  weight_decay: 0.01

  bf16: true
  fp16: false

  logging_steps: 25
  eval_strategy: steps
  eval_steps: 200
  save_strategy: steps
  save_steps: 200
  save_total_limit: 3
  load_best_model_at_end: false  # FSDP + load_best incompatible

  report_to: none
  seed: 42
  ddp_find_unused_parameters: false

Training Math

With 3 nodes × 1 GPU each:

  • Per-device batch size: 1
  • Gradient accumulation: 16
  • Effective batch size: 1 × 3 × 16 = 48 samples

This keeps per-GPU memory low while achieving a reasonable global batch size for stable training.
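
The arithmetic above is easy to script when experimenting with other node counts. A minimal sketch follows; the helper names and the 10,000-sample dataset size are illustrative assumptions, not part of the original configuration, and the step count ignores how the trainer treats a partial final batch.

```python
def effective_batch_size(per_device: int, num_gpus: int, grad_accum: int) -> int:
    """Samples contributing to one optimizer step across all ranks."""
    return per_device * num_gpus * grad_accum


def optimizer_steps_per_epoch(num_samples: int, effective_batch: int) -> int:
    """Full optimizer steps in one pass over the data (partial last batch ignored)."""
    return num_samples // effective_batch


batch = effective_batch_size(per_device=1, num_gpus=3, grad_accum=16)
print(batch)  # 48

# 10_000 is a made-up dataset size, used only to illustrate the calculation
print(optimizer_steps_per_epoch(10_000, batch))  # 208
```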

Why These LoRA Targets?

Targeting all attention projections (q_proj, k_proj, v_proj, o_proj) plus the MLP layers (gate_proj, up_proj, down_proj) gives near-full-fine-tuning quality while training under 2% of the model's parameters. For Mistral 4 Small, this typically means:

  • Total parameters: ~12B
  • Trainable LoRA parameters: ~200M (1.6%)
  • Memory savings: 60-70% vs full fine-tuning
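
For a sense of where a trainable-parameter count comes from: LoRA adds two small matrices to each targeted linear layer, A of shape r × d_in and B of shape d_out × r, so each layer contributes r × (d_in + d_out) parameters. The sketch below applies that formula to the seven target modules. The layer dimensions are hypothetical, chosen to resemble a 7B-class dense decoder; they are not the Mistral 4 Small architecture, so the total will not match the figures above.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) beside a frozen d_out x d_in weight."""
    return r * (d_in + d_out)


# HYPOTHETICAL dimensions for illustration only, not Mistral 4 Small's real ones.
HIDDEN, KV_DIM, MLP_DIM, LAYERS, RANK = 4096, 1024, 14336, 32, 16

per_layer = {
    "q_proj": lora_params(HIDDEN, HIDDEN, RANK),
    "k_proj": lora_params(HIDDEN, KV_DIM, RANK),
    "v_proj": lora_params(HIDDEN, KV_DIM, RANK),
    "o_proj": lora_params(HIDDEN, HIDDEN, RANK),
    "gate_proj": lora_params(HIDDEN, MLP_DIM, RANK),
    "up_proj": lora_params(HIDDEN, MLP_DIM, RANK),
    "down_proj": lora_params(MLP_DIM, HIDDEN, RANK),
}

total = LAYERS * sum(per_layer.values())
print(f"{total:,} trainable LoRA parameters")  # 41,943,040 trainable LoRA parameters
```

Plugging in the real layer shapes from the model's config.json gives the number that print_trainable_parameters() should report.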

Multi-Node Job Submission with Run:ai

The job submission script orchestrates everything on OpenShift:

#!/bin/bash
echo "Submitting fine-tuning job"

# Windows/MSYS2 path conversion prevention
export MSYS_NO_PATHCONV=1
export MSYS2_ARG_CONV_EXCL="*"

IMAGE="quay.io/your-registry/nvidia/pytorch@sha256:..."
JOB_NAME="mistral4small-fsdp"

runai training pytorch submit $JOB_NAME \
  --image $IMAGE \
  --annotation "k8s.v1.cni.cncf.io/networks=ssa" \
  --extended-resource "openshift.io/mellanoxnics=1" \
  --large-shm \
  --workers 2 \
  --gpu-devices-request 1 \
  --cpu-memory-request 3846 \
  --cpu-memory-limit 8996 \
  --run-as-uid 2000 \
  --run-as-gid 2000 \
  --working-dir /data/scripts/ia-gen-bench/llm/finetune-peft \
  --environment-variable PYTORCH_ALLOC_CONF=expandable_segments:True \
  --environment-variable NCCL_DEBUG="INFO" \
  --environment-variable NCCL_IB_QPS_PER_CONNECTION=1 \
  --environment-variable NCCL_IB_SPLIT_DATA_ON_QPS=1 \
  --environment-variable NCCL_SOCKET_NTHREADS=2 \
  --environment-variable NCCL_NSOCKS_PERTHREAD=2 \
  --environment-variable NCCL_SOCKET_IFNAME="net1" \
  --environment-variable CUDA_VISIBLE_DEVICES=0 \
  --environment-variable RDMAV_FORK_SAFE=1 \
  --environment-variable PIP_INDEX=https://your-artifactory/api/pypi/virtual/pypi \
  --environment-variable PIP_INDEX_URL=https://your-artifactory/api/pypi/virtual/simple \
  --environment-variable PIP_TRUSTED_HOST=your-artifactory.internal \
  --existing-pvc claimname=project-001,path=/data \
  --command -- /data/scripts/ia-gen-bench/shell/accelerate-peft-fsdp.sh

Key Infrastructure Decisions

--extended-resource "openshift.io/mellanoxnics=1" - Requests a Mellanox InfiniBand NIC for RDMA communication between nodes. Without this, NCCL falls back to TCP, which is 10-50x slower for NCCL collective operations.

--annotation "k8s.v1.cni.cncf.io/networks=ssa" - Attaches the secondary network (SR-IOV or Macvlan) for the InfiniBand fabric. This is separate from the pod’s primary network.

--large-shm - Mounts a large /dev/shm for PyTorch’s shared memory data loader workers. Without this, multi-worker data loading can crash with bus errors or hang once the small default shared-memory allocation fills up.

--workers 2 - Creates 2 additional worker pods (total 3 nodes: 1 master + 2 workers).
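
Before starting a long run, it can be worth confirming from inside the pod that the shared-memory mount really is large. A small sketch using only the standard library; the function names and the 1 GiB threshold are arbitrary illustrations, not Run:ai defaults.

```python
import shutil


def free_gib(path: str = "/dev/shm") -> float:
    """Free space at `path` in GiB, as reported by the filesystem."""
    return shutil.disk_usage(path).free / 2**30


def shm_is_large_enough(path: str = "/dev/shm", min_gib: float = 1.0) -> bool:
    # min_gib is an illustrative threshold; choose one that fits your data loader
    return free_gib(path) >= min_gib
```

Calling shm_is_large_enough() at the top of the training script and exiting early on False turns a mid-training crash into an immediate, readable failure.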

NCCL Environment Variables Explained

These environment variables tune NCCL for InfiniBand RDMA performance:

| Variable | Value | Purpose |
| --- | --- | --- |
| NCCL_DEBUG | INFO | Log NCCL initialization (verify IB detection) |
| NCCL_IB_QPS_PER_CONNECTION | 1 | Queue Pairs per connection (balance bandwidth) |
| NCCL_IB_SPLIT_DATA_ON_QPS | 1 | Split data across QPs for parallelism |
| NCCL_SOCKET_NTHREADS | 2 | Socket threads for control plane |
| NCCL_NSOCKS_PERTHREAD | 2 | Sockets per thread |
| NCCL_SOCKET_IFNAME | net1 | Use the SR-IOV/Mellanox interface, not eth0 |
| CUDA_VISIBLE_DEVICES | 0 | Each pod sees one GPU |
| RDMAV_FORK_SAFE | 1 | Allow RDMA after fork (required for PyTorch DataLoader) |

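Setting NCCL_SOCKET_IFNAME="net1" is critical: it points NCCL at the secondary (SR-IOV/Mellanox) interface instead of the default Kubernetes pod network. Without it, NCCL's bootstrap and any socket-based traffic use the pod network, and gradient synchronization can end up on slow TCP. With NCCL_DEBUG=INFO, the startup log shows which transport was actually selected.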
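
A misconfigured pod often shows up only as slow training, so a quick pre-flight check of these variables can save time. The sketch below compares the environment against the values from the table and confirms that the named interface exists in the pod. The function name and the choice of which variables to enforce are assumptions for illustration.

```python
import os
import socket

# Values taken from the table above
EXPECTED = {
    "NCCL_SOCKET_IFNAME": "net1",
    "NCCL_IB_QPS_PER_CONNECTION": "1",
    "NCCL_IB_SPLIT_DATA_ON_QPS": "1",
    "RDMAV_FORK_SAFE": "1",
}


def check_nccl_env(env, interfaces):
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems = []
    for name, expected in EXPECTED.items():
        actual = env.get(name)
        if actual != expected:
            problems.append(f"{name}={actual!r}, expected {expected!r}")
    ifname = env.get("NCCL_SOCKET_IFNAME")
    if ifname and ifname not in interfaces:
        problems.append(f"interface {ifname!r} not present in the pod")
    return problems


if __name__ == "__main__":
    # socket.if_nameindex() lists (index, name) pairs for the pod's interfaces
    names = [name for _, name in socket.if_nameindex()]
    for problem in check_nccl_env(os.environ, names):
        print("WARNING:", problem)
```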

Memory Optimization Stack

The configuration stacks multiple memory optimizations:

  1. LoRA - only 1.6% of parameters are trainable
  2. FSDP FULL_SHARD - parameters sharded across 3 nodes
  3. BF16 mixed precision - halves activation memory
  4. CPU offloading - optimizer states on CPU RAM
  5. Gradient accumulation (16) - tiny per-step batch = low activation memory
  6. Parameter prefetching - overlaps communication with computation

This stack allows fine-tuning a 12B model with a single GPU per node, which is significantly more accessible than the 8×H100 setups typically required.
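
To see roughly how the first two items interact, the sketch below estimates the sharded weight footprint per rank and the AdamW state for LoRA versus full fine-tuning, using the parameter counts quoted in this article. It is a back-of-the-envelope estimate under stated assumptions (bf16 weights, two fp32 AdamW moments per trainable parameter) and it ignores activations, gradients, and temporary all-gather buffers.

```python
GIB = 2**30


def params_gib(num_params: float, bytes_per_param: int) -> float:
    """Memory in GiB for `num_params` values of the given width."""
    return num_params * bytes_per_param / GIB


TOTAL, TRAINABLE, WORLD = 12e9, 200e6, 3  # figures quoted in this article

# bf16 base weights (2 bytes each), evenly sharded across ranks by FSDP
weights_per_rank = params_gib(TOTAL, 2) / WORLD
# AdamW keeps two moments per trainable parameter (fp32 assumed: 2 x 4 bytes)
adam_lora = params_gib(TRAINABLE, 8)
adam_full = params_gib(TOTAL, 8)

print(f"bf16 weights per rank : {weights_per_rank:6.1f} GiB")
print(f"AdamW state, LoRA only: {adam_lora:6.1f} GiB")
print(f"AdamW state, full FT  : {adam_full:6.1f} GiB")
```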

Why FSDP2 Over DeepSpeed?

| Feature | FSDP2 | DeepSpeed ZeRO-3 |
| --- | --- | --- |
| PyTorch native | ✅ Yes | ❌ Separate library |
| PEFT compatibility | ✅ orig_params=true | ⚠️ Requires workarounds |
| Torch compile | ✅ Supported | ❌ Limited |
| Maintenance | PyTorch team | Microsoft |
| Configuration | Accelerate YAML | JSON config |
| Activation checkpointing | Per-layer control | Global toggle |

FSDP2 is the recommended approach for new projects in 2026, especially when using LoRA/PEFT adapters.

Production Considerations

Checkpoint Strategy

save_strategy: steps
save_steps: 200
save_total_limit: 3
load_best_model_at_end: false  # Incompatible with FSDP

With FSDP, load_best_model_at_end doesn’t work because the model is sharded across workers. Instead, save checkpoints and select the best one post-training based on validation loss.
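
One way to do the post-training selection, assuming the standard Hugging Face Trainer layout in which each checkpoint-<step> directory contains a trainer_state.json whose log_history records eval_loss per evaluation step. The function name is hypothetical, and because save_total_limit: 3 prunes older checkpoints, only directories that still exist are considered.

```python
import json
from pathlib import Path


def best_checkpoint(output_dir):
    """Return (path, eval_loss) of the surviving checkpoint with the lowest eval loss."""
    checkpoints = {
        int(p.name.split("-")[1]): p
        for p in Path(output_dir).glob("checkpoint-*")
        if p.is_dir() and p.name.split("-")[1].isdigit()
    }
    if not checkpoints:
        raise FileNotFoundError(f"no checkpoint-* directories in {output_dir}")

    # The newest checkpoint's trainer_state.json holds the full log history.
    latest = checkpoints[max(checkpoints)]
    state = json.loads((latest / "trainer_state.json").read_text())

    # Keep only evaluations whose step still has a checkpoint on disk.
    evals = {
        entry["step"]: entry["eval_loss"]
        for entry in state["log_history"]
        if "eval_loss" in entry and entry["step"] in checkpoints
    }
    if not evals:
        raise ValueError("no evaluation logged at a saved checkpoint step")

    step = min(evals, key=evals.get)
    return checkpoints[step], evals[step]
```

With eval_steps and save_steps both set to 200, every saved checkpoint has a matching evaluation entry, which is what makes this lookup work.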

Monitoring Training

Set logging_steps: 25 to get frequent loss updates. For multi-node training, watch for:

  • Loss divergence (learning rate too high)
  • NCCL timeouts (network issues)
  • OOM on specific workers (memory imbalance)
  • Gradient norm spikes (data quality issues)
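
Spikes are easier to catch automatically than by reading logs. A minimal sketch that flags values far above the recent median; the function name, window size, and threshold factor are illustrative assumptions, and the sample values are made up.

```python
from statistics import median


def find_spikes(values, window=5, factor=3.0):
    """Indices whose value exceeds `factor` x the median of the preceding `window` values."""
    spikes = []
    for i in range(window, len(values)):
        baseline = median(values[i - window:i])
        if baseline > 0 and values[i] > factor * baseline:
            spikes.append(i)
    return spikes


grad_norms = [1.0, 1.1, 0.9, 1.0, 1.2, 9.5, 1.1, 1.0]  # made-up values
print(find_spikes(grad_norms))  # [5]
```

The same function can be applied to the logged loss to spot early divergence.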

Data Pipeline

The configuration uses WikiText-2 as a public benchmark dataset. In production, you would replace this with your domain-specific dataset while keeping the same data pipeline structure:

data:
  dataset_name_or_path: /data/input/Datasets/your-domain-corpus
  text_field: text
  train_split: train
  val_split: validation
  max_length: 1024

The max_length: 1024 keeps memory bounded. For longer contexts, reduce batch size or add gradient checkpointing.

Scaling Considerations

To scale from 3 to more nodes:

  1. Update num_machines and num_processes in fsdp.yaml
  2. Update --workers in the Run:ai submit command
  3. Adjust gradient_accumulation_steps to maintain effective batch size
  4. Monitor InfiniBand bandwidth: more nodes = more AllReduce communication

The sweet spot for Mistral 4 Small with LoRA is typically 2-4 nodes. Beyond that, communication overhead starts to dominate training time.
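
Step 3 amounts to dividing the target batch by the new GPU count. A short sketch, with a hypothetical helper name, that also shows what happens when the division is not exact:

```python
import math


def rescale_accumulation(target_batch: int, per_device: int, num_gpus: int):
    """Accumulation steps that reach at least `target_batch`, plus the batch obtained."""
    steps = math.ceil(target_batch / (per_device * num_gpus))
    return steps, steps * per_device * num_gpus


for gpus in (2, 3, 4, 5):
    steps, actual = rescale_accumulation(48, per_device=1, num_gpus=gpus)
    print(f"{gpus} GPUs -> gradient_accumulation_steps={steps}, effective batch={actual}")
```

With 5 GPUs the batch of 48 cannot be matched exactly, so the helper rounds up to 10 accumulation steps and an effective batch of 50.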

Prerequisites: PyTorch and Transformers Versions

Fine-tuning Mistral 4 Small requires specific version alignment:

  • PyTorch container: nvcr.io/nvidia/pytorch:26.02 (CUDA 12.8, PyTorch 2.11)
  • Transformers: Latest version (4.48+). The Mistral 4 architecture (Mistral4ForCausalLM) was added recently and is not available in older releases
  • PEFT: 0.14+ for FSDP2 compatibility
  • TRL: 0.15+ for SFTTrainer with processing_class parameter
  • Accelerate: 1.3+ for FSDP2 configuration support
  • torchao: Latest (for quantization-aware training support)
  • pynvml: For GPU memory monitoring

# Install on top of the PyTorch 26.02 NGC container
pip install --quiet peft datasets trl transformers accelerate pynvml
# Update torchao for latest quantization support
pip install -U torchao

The Mistral4ForCausalLM class was merged into transformers after the Mistral 4 Small release. If you get ImportError: cannot import name 'Mistral4ForCausalLM', upgrade transformers to the latest version or install from source:

uv pip install git+https://github.com/huggingface/transformers.git
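
Since the install commands above are unpinned, it helps to verify what actually got installed before launching a multi-node job. A sketch using only the standard library; the minimum versions are the ones listed above, while the function names and the simplified version parsing (leading numeric components only) are assumptions for illustration.

```python
import re
from importlib import metadata

# Minimums as listed in the prerequisites above
MINIMUMS = {"transformers": "4.48", "peft": "0.14", "trl": "0.15", "accelerate": "1.3"}


def version_tuple(version: str):
    """Leading numeric components only: '4.48.0.dev0' -> (4, 48, 0)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if not match:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    return version_tuple(installed) >= version_tuple(minimum)


def report(minimums=MINIMUMS):
    """Print one status line per package."""
    for package, minimum in minimums.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            print(f"{package}: not installed (need >= {minimum})")
            continue
        status = "ok" if meets_minimum(installed, minimum) else f"too old, need >= {minimum}"
        print(f"{package} {installed}: {status}")


if __name__ == "__main__":
    report()
```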

The Accelerate Launch Script

The shell script that bootstraps the FSDP training run handles dependency installation, environment setup, NVLink verification, and accelerate launch:

#!/bin/bash
set -euo pipefail

# --- FSDP Configuration Notes ---
# Backward prefetch: during backward of layer i, prefetch layer i-1 in parallel
# Masks allgather NVLink latency behind GPU computation
#
# State dict: FULL_STATE_DICT
# Rank 0 reconstitutes the full checkpoint at each save
# -> directly reloadable for inference (no FSDP needed)
# Alternative: SHARDED_STATE_DICT (faster, but requires FSDP to reload)
#
# LoRA + FSDP: fsdp_use_orig_params
# CRITICAL: without this flag, FSDP flattens parameters into tensor IDs,
# breaking PEFT's selective grad: impossible to train only LoRA adapters
#
# RAM-efficient loading:
# fsdp_cpu_ram_efficient_loading: only rank 0 loads 238 GB from disk
# fsdp_sync_module_states: rank 0 broadcasts weights to others via NVLink
# Without these: each process tries to load 238 GB -> deadlock/OOM

# --- Dependencies ---
# Install from internal mirror (PIP_INDEX_URL configured by Run:ai)
pip install --quiet peft datasets trl transformers accelerate pynvml

echo "updating torchao..."
pip install -U torchao

# Make pip binaries available (accelerate, etc.)
export PATH=$PATH:$HOME/.local/bin

# --- Environment Variables ---
# Avoid deadlocks with HuggingFace fast tokenizers + multiprocess
export TOKENIZERS_PARALLELISM=false

# Better GPU memory fragmentation management (reduces OOM)
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Working directory
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(dirname "$SCRIPT_DIR")"
FINETUNE_DIR="$REPO_ROOT/llm/finetune-peft"

# --- Startup Traces ---
echo ""
echo " FSDP fine-tuning - Mistral-Small-4 119B - 2× H200 NVL"
echo ""
echo " date       : $(date '+%Y-%m-%d %H:%M:%S')"
echo " hostname   : $(hostname)"
echo " RANK       : ${RANK:-0}"
echo " MASTER_ADDR: ${MASTER_ADDR:-localhost}"
echo " MASTER_PORT: ${MASTER_PORT:-29500}"
echo " FINETUNE_DIR: $FINETUNE_DIR"
echo ""

# GPU topology check
nvidia-smi --query-gpu=index,name,memory.total,driver_version \
           --format=csv,noheader

echo ""

# NVLink: verify topology before launch
nvidia-smi nvlink --status -i 0 2>/dev/null | grep -E "Link|Active|Inactive" \
  || echo "NVLink status unavailable"

Key Script Design Decisions

  1. TOKENIZERS_PARALLELISM=false - HuggingFace fast tokenizers parallelize with Rust threads internally. If the process forks after those threads have been used, as happens with PyTorch DataLoader workers, the child processes can deadlock. Disabling parallelism forces sequential tokenization (negligible performance impact since tokenization is not the bottleneck).

  2. PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True - PyTorch 2.x memory allocator feature that allows CUDA memory segments to grow dynamically. Reduces fragmentation-induced OOM errors by 40-60% for large model training.

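  3. NVLink verification - The script checks NVLink connectivity before launching training. If NVLink is inactive, NCCL falls back to PCIe, reducing allgather bandwidth from 900 GB/s to 64 GB/s and making FSDP training impractically slow. NVLink only links GPUs inside one node, so this check applies to multi-GPU nodes such as the 2× H200 NVL machine named in the startup banner; with one GPU per node, inter-node traffic goes over InfiniBand instead.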

  4. RAM-efficient loading explained - The comments in the original script (written in French) document a critical insight: without fsdp_cpu_ram_efficient_loading, each of 3 processes would independently load 238 GB of model weights from disk, requiring 714 GB of CPU RAM and likely deadlocking. With it, only rank 0 loads the weights and then broadcasts them to the other ranks.
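
The script excerpt stops before the launch step itself. As a rough sketch of what that step could look like, the helper below assembles an accelerate launch command from the same RANK, MASTER_ADDR, and MASTER_PORT variables the script prints at startup. The function name is hypothetical, the file names are taken from elsewhere in this article, and it assumes one process per pod so that RANK doubles as the machine rank; the original script may do this differently.

```python
import os
import shlex


def build_launch_command(env, config_file, script, script_args, num_machines=3):
    """Assemble an `accelerate launch` argv for one node of a multi-node run."""
    return [
        "accelerate", "launch",
        "--config_file", config_file,
        "--num_machines", str(num_machines),
        "--num_processes", str(num_machines),  # one GPU per node
        "--machine_rank", env.get("RANK", "0"),
        "--main_process_ip", env.get("MASTER_ADDR", "localhost"),
        "--main_process_port", env.get("MASTER_PORT", "29500"),
        script, *script_args,
    ]


if __name__ == "__main__":
    cmd = build_launch_command(
        os.environ,
        config_file="fsdp.yaml",
        script="finetune_mistral4small.py",
        script_args=["--config", "configs/mistral4small_2xH200.yaml"],
    )
    # Print the command so it can be inspected before being executed
    print(shlex.join(cmd))
```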

The Complete Training Script

The Python training script ties everything together using HuggingFace’s trl (SFTTrainer) and peft libraries, with Axolotl-inspired patterns:

## Reference: https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4

# Usage:
# CUDA_VISIBLE_DEVICES=0,1 python finetune_mistral4small.py \
#   --config configs/mistral4small_2xH200.yaml
#
# With overrides:
# CUDA_VISIBLE_DEVICES=0,1 python finetune_mistral4small.py \
#   --config configs/mistral4small_2xH200.yaml \
#   --override train.learning_rate=1e-4 --override lora.r=32

import argparse
import logging
import os
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Optional

import torch
import yaml
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoConfig, AutoTokenizer, Mistral4ForCausalLM
from trl import SFTConfig, SFTTrainer

from llm_train_utils import LLMStepProfiler, print_hw_summary, _nval_gpu_stats

Configuration Dataclasses

The script uses typed dataclasses for clean configuration management:

@dataclass
class ModelConfig:
    model_name_or_path: str = "/data/input/Models/Mistral-Small-4-119B-2603"
    torch_dtype: str = "bfloat16"
    device_map: str = "auto"
    use_cache: bool = False
    attn_implementation: Optional[str] = None

@dataclass
class LoRAConfig:
    r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    target_modules: list = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )
    bias: str = "none"
    task_type: str = "CAUSAL_LM"

Note the LoRA rank of 16 with alpha of 32, an alpha-to-rank ratio of 2:1. This is a common setting that tends to train stably. The dataclass default targets attention projections only, while the YAML config version also includes the MLP layers for higher quality.
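
The usage header shows overrides such as --override lora.r=32, but the parsing code is not part of the excerpt. Below is one possible way to apply a section.key=value override to a nested configuration. The function names are hypothetical, the starting values are illustrative, and the typing rules (Python literals plus true/false/null) are an assumption about what the original script accepts.

```python
import ast


def parse_value(text: str):
    """Best-effort typing: true/false/null by name, other literals via ast, else string."""
    keywords = {"true": True, "false": False, "null": None}
    if text.lower() in keywords:
        return keywords[text.lower()]
    try:
        return ast.literal_eval(text)  # handles 32, 1e-4, "quoted strings"
    except (ValueError, SyntaxError):
        return text  # plain string such as cosine or a file path


def apply_override(config: dict, override: str) -> None:
    """Apply one 'section.key=value' override to a nested dict, in place."""
    dotted, sep, raw = override.partition("=")
    if not sep:
        raise ValueError(f"expected section.key=value, got {override!r}")
    *parents, leaf = dotted.split(".")
    node = config
    for name in parents:
        node = node.setdefault(name, {})
    node[leaf] = parse_value(raw)


# Illustrative starting values, not the article's YAML
config = {"lora": {"r": 16, "lora_alpha": 32}, "train": {"learning_rate": 2e-4}}
for item in ("train.learning_rate=1e-4", "lora.r=32", "train.lr_scheduler_type=cosine"):
    apply_override(config, item)
print(config)
```

Note that raising r to 32 without touching lora_alpha changes the alpha-to-rank ratio from 2:1 to 1:1, so the two are usually overridden together.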

SFTTrainer Integration

The training loop uses HuggingFace’s SFTTrainer with full configuration passthrough:

sft_config = SFTConfig(
    output_dir=train_cfg.output_dir,
    num_train_epochs=train_cfg.num_train_epochs,
    per_device_train_batch_size=train_cfg.per_device_train_batch_size,
    per_device_eval_batch_size=train_cfg.per_device_eval_batch_size,
    gradient_accumulation_steps=train_cfg.gradient_accumulation_steps,
    gradient_checkpointing=train_cfg.gradient_checkpointing,
    optim=train_cfg.optim,
    learning_rate=train_cfg.learning_rate,
    lr_scheduler_type=train_cfg.lr_scheduler_type,
    warmup_ratio=train_cfg.warmup_ratio,
    weight_decay=train_cfg.weight_decay,
    bf16=train_cfg.bf16,
    fp16=train_cfg.fp16,
    logging_steps=train_cfg.logging_steps,
    eval_strategy=train_cfg.eval_strategy,
    eval_steps=train_cfg.eval_steps,
    save_strategy=train_cfg.save_strategy,
    save_steps=train_cfg.save_steps,
    save_total_limit=train_cfg.save_total_limit,
    seed=train_cfg.seed,
    dataset_text_field=data_cfg.text_field,
    max_length=data_cfg.max_length,
    packing=False
)

trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=tokenizer,
    callbacks=[LLMStepProfiler(
        debug_freq=args.debug_freq,
        seq_len=data_cfg.max_length,
        output_dir=train_cfg.output_dir,
        report_meta={
            "model": {"name": model_cfg.model_name_or_path, "dtype": model_cfg.torch_dtype},
            "lora": {"r": lora_cfg.r, "alpha": lora_cfg.lora_alpha, "modules": lora_cfg.target_modules},
            "data": {"dataset": data_cfg.dataset_name_or_path, "max_length": data_cfg.max_length},
            "train": {"epochs": train_cfg.num_train_epochs, "lr": train_cfg.learning_rate,
                      "batch": train_cfg.per_device_train_batch_size,
                      "accum": train_cfg.gradient_accumulation_steps},
        },
        script_name="finetune_mistral4small",
    )],
)

if trainer.accelerator.is_main_process and hasattr(trainer.model, "print_trainable_parameters"):
    trainer.model.print_trainable_parameters()

logger.info("Starting training...")
trainer.train()
logger.info("Training complete.")

# Save final model
save_path = os.path.join(train_cfg.output_dir, "final")
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)
logger.info(f"Model saved to {save_path}")

Key Design Patterns

  1. LLMStepProfiler callback - Custom profiler that tracks GPU utilization, memory, and throughput per step. Essential for optimizing training efficiency.

  2. packing=False - Disables sequence packing. While packing improves throughput by concatenating short sequences, it can affect training quality for instruction-tuned models.

  3. report_meta dictionary - Logs all hyperparameters to the profiler for experiment tracking without external tools like W&B or MLflow.

  4. Final model save - Saves only the LoRA adapter weights (not the full model), keeping checkpoint sizes at ~400 MB instead of 24 GB.
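
The LLMStepProfiler source is not shown, but the throughput figure such a profiler reports can be approximated from values already in the config. A minimal sketch with a hypothetical function name and a made-up step time; because packing is disabled and real sequences are often shorter than max_length, the result is an upper bound.

```python
def tokens_per_second(per_device_batch, seq_len, grad_accum, num_gpus, step_seconds):
    """Upper-bound throughput for one optimizer step, assuming full-length sequences."""
    tokens = per_device_batch * seq_len * grad_accum * num_gpus
    return tokens / step_seconds


# step_seconds=60.0 is a made-up timing, for illustration only
rate = tokens_per_second(1, 1024, 16, 3, step_seconds=60.0)
print(f"{rate:.1f} tokens/s")  # 819.2 tokens/s
```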


FSDP2 + LoRA is the most memory-efficient path to fine-tuning large models on limited GPU hardware. Start with the smallest number of nodes that fits your model, and scale only when training time becomes the bottleneck.
