Overview
The Jamba models can be fine-tuned using several approaches:
- Full Fine-tuning: Complete model parameter updates (requires significant GPU resources)
- LoRA (Low-Rank Adaptation): Parameter-efficient fine-tuning approach
- QLoRA: Combines LoRA with 4-bit quantization for single GPU training
Full Fine-tuning
Full fine-tuning updates all model parameters, which generally yields the strongest adaptation to your data but also demands the most compute and memory.
For a comprehensive implementation guide using AWS SageMaker with multi-node and FSDP configuration, see the AI21 SageMaker Fine-tuning Repository.
Full fine-tuning requires multiple high-memory GPUs.
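The structure mirrors the LoRA recipe below; the key difference is that no peft_config is passed to SFTTrainer, so every parameter is updated. The following is only a minimal sketch, assuming a multi-GPU launch (for example via accelerate with FSDP), placeholder paths, and the same example dataset used later in this guide; the SageMaker repository above covers a production multi-node setup.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig

tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Mini-1.7")
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/AI21-Jamba-Mini-1.7",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# Placeholder dataset - replace with your own conversational dataset
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")

training_args = SFTConfig(
    output_dir="./full-finetune-results",  # hypothetical output path
    num_train_epochs=2,
    per_device_train_batch_size=1,  # full fine-tuning leaves little per-GPU headroom
    learning_rate=1e-5,
    gradient_checkpointing=True,
    max_seq_length=4096,
)

# No peft_config here: all model parameters are trainable
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()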
LoRA Fine-tuning
LoRA (Low-Rank Adaptation) injects compact, low-rank adapter layers into a frozen pretrained model, letting you specialize it for your task while training only a few percent of the parameters, with minimal extra compute and storage and little impact on accuracy or inference speed.
Prerequisites
Before starting LoRA fine-tuning, install the required dependencies:
pip install trl transformers torch datasets peft
This LoRA fine-tuning example uses bfloat16 precision and requires ~130GB GPU RAM (e.g., 2x A100 80GB GPUs).
Implementation
Load Model and Tokenizer
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Mini-1.7")
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/AI21-Jamba-Mini-1.7",
    device_map="auto",  # Automatically distribute across available GPUs
    torch_dtype=torch.bfloat16,  # Use mixed precision for memory efficiency
    attn_implementation="flash_attention_2",  # Optimized attention implementation
)
Configure LoRA Parameters
lora_config = LoraConfig(
    r=8,  # Rank of adaptation - controls the number of trainable parameters
    target_modules=[
        "embed_tokens",
        "x_proj", "in_proj", "out_proj",  # mamba layers
        "gate_proj", "up_proj", "down_proj",  # mlp layers
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention layers
    ],
    task_type="CAUSAL_LM",
    bias="none",
)
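As an optional sanity check, you can wrap the model with PEFT and print how many parameters the rank-8 adapters actually train. Note that SFTTrainer applies lora_config itself later in this guide, so run this check separately (or reload the model afterwards) rather than passing the wrapped model onward.
from peft import get_peft_model

# Optional check of adapter size; exact numbers depend on the model revision
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()
# Prints something like: "trainable params: ... || all params: ... || trainable%: ..."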
Prepare Your Dataset
# Load dataset (replace with your own dataset)
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
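If you substitute your own data, each row should follow the same conversational layout as this example dataset: a messages column holding a list of role/content dictionaries. A minimal sketch with hypothetical content:
from datasets import Dataset

# Hypothetical rows illustrating the expected "messages" structure
custom_dataset = Dataset.from_list([
    {
        "messages": [
            {"role": "user", "content": "Summarize the benefits of LoRA fine-tuning."},
            {"role": "assistant", "content": "LoRA trains small adapter matrices instead of the full model..."},
        ]
    },
])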
Configure Training Settings
training_args = SFTConfig(
    output_dir="/dev/shm/results",  # Where to save the model
    logging_dir="./logs",  # Where to save training logs
    num_train_epochs=2,  # Number of training epochs
    per_device_train_batch_size=4,  # Batch size per GPU
    learning_rate=1e-5,  # Learning rate for fine-tuning
    logging_steps=10,  # Log training metrics every 10 steps
    gradient_checkpointing=True,  # Save memory at cost of compute
    max_seq_length=4096,  # Maximum sequence length
    save_steps=100,  # Save model checkpoint every 100 steps
)
Initialize and Start Training
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
)
trainer.train()
The dataset in this example uses the conversational format (a messages column), so SFTTrainer automatically applies Jamba's chat template. For more information about supported dataset formats and advanced SFTTrainer features, see the TRL documentation.
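Once training finishes, the adapter weights can be saved on their own, or merged back into the base model for standalone deployment. A brief sketch, with output paths chosen purely for illustration:
# Save only the LoRA adapter weights (small compared to the full checkpoint)
trainer.save_model("./jamba-mini-lora-adapter")

# Optionally merge the adapter into the base weights and save a standalone model
merged_model = trainer.model.merge_and_unload()
merged_model.save_pretrained("./jamba-mini-lora-merged")
tokenizer.save_pretrained("./jamba-mini-lora-merged")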
For Jamba Large LoRA fine-tuning, we recommend using the qLoRA+FSDP approach detailed in the QLoRA section below, as it provides better memory efficiency for the larger model.
QLoRA Fine-tuning
QLoRA combines LoRA with 4-bit quantization, making it possible to fine-tune on a single 80GB GPU while maintaining good performance.
Prerequisites
Before starting QLoRA fine-tuning, install the required dependencies:
pip install trl transformers torch datasets peft bitsandbytes
Implementation
Initialize Tokenizer and Configure Quantization
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Mini-1.7")
# Configure 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Enable 4-bit quantization
    bnb_4bit_quant_type="nf4",  # Use NormalFloat 4-bit quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # Compute in bfloat16 for better stability
)
Load Model with Quantization
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/AI21-Jamba-Mini-1.7",
    device_map="auto",  # Automatically distribute across available GPUs
    quantization_config=quantization_config,  # Apply 4-bit quantization
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
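Before training, it can be worth confirming that the 4-bit weights actually fit your GPU budget. One quick, optional check:
# Report the quantized model's approximate weight memory in GB
footprint_gb = model.get_memory_footprint() / 1024**3
print(f"Quantized model footprint: {footprint_gb:.1f} GB")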
Configure LoRA Parameters
lora_config = LoraConfig(
    r=8,  # Rank of adaptation - controls trainable parameters
    target_modules=[
        "embed_tokens",
        "x_proj", "in_proj", "out_proj",  # mamba layers
        "gate_proj", "up_proj", "down_proj",  # mlp layers
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention layers
    ],
    task_type="CAUSAL_LM",
    bias="none",
)
Prepare Your Dataset
# Load dataset (replace with your own dataset)
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
Configure Training Settings
training_args = SFTConfig(
    output_dir="./results",  # Where to save the model
    logging_dir="./logs",  # Where to save training logs
    num_train_epochs=2,  # Number of training epochs
    per_device_train_batch_size=8,  # Higher batch size possible with quantization
    learning_rate=1e-5,  # Learning rate for fine-tuning
    logging_steps=1,  # Log training metrics every step
    gradient_checkpointing=True,  # Save memory at cost of compute
    gradient_checkpointing_kwargs={"use_reentrant": False},  # Required for some models
    save_steps=100,  # Save model checkpoint every 100 steps
    max_seq_length=4096,  # Maximum sequence length
)
Initialize and Start Training
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
)
trainer.train()
Jamba Large QLoRA Fine-tuning
Jamba Large fine-tuning requires 8x A100/H100 80GB GPUs and uses qLoRA+FSDP. This approach uses the axolotl framework with a modified transformers version to keep memory usage manageable.
Due to its size, Jamba Large 1.7 has to be quantized in order to run training on a single 8-GPU node. This can happen either at the start of the training job or in a separate pre-processing step. If you want to pre-quantize the model, you can do so easily with bitsandbytes (make sure to set bnb_4bit_quant_storage=torch.bfloat16 so the quantized weights remain compatible with FSDP).
Install Dependencies
# Install axolotl and dependencies
git clone https://github.com/axolotl-ai-cloud/axolotl
cd axolotl
pip3 install packaging ninja
pip3 install -e '.[flash-attn,deepspeed]'
pip install bitsandbytes~=0.43.3
pip install trl
pip install peft~=0.12.0
pip install accelerate~=0.33.0
pip install mamba-ssm "causal-conv1d>=1.2.0"
pip install git+https://github.com/xgal/transformers@897f80665c37c531b7803f92655dbc9b3a593fe7
Pre-quantize Model (Optional)
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
model_name = "ai21labs/AI21-Jamba-Large-1.7"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_quant_storage=torch.bfloat16,
)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained('AI21-Jamba-Large-1.7-BNB-nf4-bf16')
quantized_model.save_pretrained('AI21-Jamba-Large-1.7-BNB-nf4-bf16')
Run Training
# Run training with axolotl (change base_model to pre-quantized model if using pre-quantization)
accelerate launch -m axolotl.cli.train examples/jamba/qlora_fsdp.yaml
For detailed configuration files and examples, visit the axolotl Jamba examples. The modified transformers version keeps CPU RAM usage during model loading in check, reducing the requirement from over 1.6TB to roughly 200GB.