# Fine-tuning

> Fine-tuning is the process of adapting a pre-trained model to perform better on specific tasks by training it on domain-specific data. Learn how to fine-tune Jamba models with full fine-tuning, LoRA, and QLoRA.

## Overview

The Jamba models can be fine-tuned using several approaches:

* **Full Fine-tuning**: Updates all model parameters (requires significant GPU resources)
* **LoRA (Low-Rank Adaptation)**: Parameter-efficient fine-tuning approach
* **QLoRA**: Combines LoRA with 4-bit quantization to cut memory use, enabling single-GPU training for Jamba Mini

## Full Fine-tuning

Full fine-tuning updates all model parameters. It gives the model the most room to adapt to your data, but it also has the highest memory and compute requirements of the three approaches.

For a comprehensive implementation guide using AWS SageMaker with multi-node and FSDP configuration, see the [AI21 SageMaker Fine-tuning Repository](https://github.com/AI21Labs/hf-finetune-sagemaker).

<Note>
  Full fine-tuning requires multiple high-memory GPUs.
</Note>
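To see why, here is a rough memory estimate for full fine-tuning. It is a sketch built on a common rule of thumb (an assumption, not an AI21 figure): mixed-precision training with the Adam optimizer needs about 16 bytes per parameter for weights, gradients, and optimizer state, before counting activations. The ~52B total parameter count for Jamba Mini is also an assumed approximate figure.

```python
# Rule-of-thumb assumption for mixed-precision Adam training, per parameter:
#   2 bytes  bf16 weights
#   2 bytes  bf16 gradients
#   4 bytes  fp32 master copy of the weights
#   8 bytes  fp32 Adam first and second moments
# Activations and framework overhead are NOT included.
BYTES_PER_PARAM_FULL_FT = 2 + 2 + 4 + 8


def full_finetune_state_gb(num_params: float) -> float:
    """Approximate GB for weights, gradients, and optimizer state."""
    return num_params * BYTES_PER_PARAM_FULL_FT / 1e9


def min_gpus(num_params: float, gpu_memory_gb: float = 80.0) -> int:
    """Lower bound on GPU count if that state is sharded evenly (e.g. with FSDP)."""
    needed_gb = full_finetune_state_gb(num_params)
    return int(-(-needed_gb // gpu_memory_gb))  # ceiling division


# Assumption: Jamba Mini has roughly 52B total parameters.
jamba_mini_params = 52e9
print(f"Training state: {full_finetune_state_gb(jamba_mini_params):.0f} GB")
print(f"At least {min_gpus(jamba_mini_params)} x 80GB GPUs, before activations")
```

Treat the result as a lower bound: activations, communication buffers, and memory fragmentation add more on top.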

## LoRA Fine-tuning

LoRA (Low-Rank Adaptation) fine-tuning injects compact, low-rank adapter layers into a frozen pretrained model. This lets you specialize the model for your task by training only a few percent of its parameters, with minimal extra compute and storage and little loss in accuracy or inference speed.
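The parameter savings follow from how LoRA factors each update. For a frozen linear layer with weight shape `d_out x d_in`, LoRA trains two small matrices, `A` (`r x d_in`) and `B` (`d_out x r`), and adds their product `B @ A` to the frozen weight. The sketch below counts the trainable parameters this adds to one layer; the 4096 x 4096 layer size is a hypothetical example, and `r=8` matches the rank used in the configurations on this page.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_in -> d_out linear layer."""
    # A has shape (r, d_in) and B has shape (d_out, r); the base weight stays frozen.
    return r * d_in + d_out * r


def base_params(d_in: int, d_out: int) -> int:
    """Parameters in the frozen d_out x d_in weight matrix (bias ignored)."""
    return d_in * d_out


# Hypothetical 4096 x 4096 projection with the rank used on this page.
d_in, d_out, r = 4096, 4096, 8
added = lora_params(d_in, d_out, r)
frozen = base_params(d_in, d_out)
print(f"LoRA adds {added:,} trainable params vs {frozen:,} frozen")
print(f"Trainable fraction for this layer: {added / frozen:.2%}")
```

Because the adapter size grows with `r * (d_in + d_out)` rather than `d_in * d_out`, raising `r` increases the number of trainable parameters only linearly.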

### Prerequisites

Before starting LoRA fine-tuning, install the required dependencies:

```bash theme={"system"}
pip install trl transformers torch datasets peft
pip install flash-attn --no-build-isolation  # required for attn_implementation="flash_attention_2"
```

<Note>
  This LoRA fine-tuning example uses bfloat16 precision and requires \~130GB of GPU memory (e.g., 2x A100 80GB GPUs).
</Note>

### Implementation

<Tabs>
  <Tab title="Jamba Mini">
    <Steps>
      <Step title="Load Model and Tokenizer">
        ```python theme={"system"}
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM
        from datasets import load_dataset
        from trl import SFTTrainer, SFTConfig
        from peft import LoraConfig

        tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Mini-1.7")
        model = AutoModelForCausalLM.from_pretrained(
            "ai21labs/AI21-Jamba-Mini-1.7",
            device_map="auto",  # Automatically distribute across available GPUs
            torch_dtype=torch.bfloat16,  # Load weights in bfloat16 to halve memory vs. float32
            attn_implementation="flash_attention_2",  # Optimized attention implementation
        )
        ```
      </Step>

      <Step title="Configure LoRA Parameters">
        ```python theme={"system"}
        lora_config = LoraConfig(
            r=8,  # Rank of adaptation - controls the number of trainable parameters
            target_modules=[
                "embed_tokens",
                "x_proj", "in_proj", "out_proj",  # mamba layers
                "gate_proj", "up_proj", "down_proj",  # mlp layers
                "q_proj", "k_proj", "v_proj", "o_proj",  # attention layers
            ],
            task_type="CAUSAL_LM",
            bias="none",
        )
        ```
      </Step>

      <Step title="Prepare Your Dataset">
        ```python theme={"system"}
        # Load dataset (replace with your own dataset)
        dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
        ```
      </Step>

      <Step title="Configure Training Settings">
        ```python theme={"system"}
        training_args = SFTConfig(
            output_dir="/dev/shm/results",  # Where to save the model (/dev/shm is RAM-backed; copy results to persistent storage)
            logging_dir="./logs",  # Where to save training logs
            num_train_epochs=2,  # Number of training epochs
            per_device_train_batch_size=4,  # Batch size per GPU
            learning_rate=1e-5,  # Learning rate for fine-tuning
            logging_steps=10,  # Log training metrics every 10 steps
            gradient_checkpointing=True,  # Save memory at cost of compute
            max_length=4096,  # Maximum sequence length (called max_seq_length in older TRL versions)
            save_steps=100,  # Save model checkpoint every 100 steps
        )
        ```
      </Step>

      <Step title="Initialize and Start Training">
        ```python theme={"system"}
        trainer = SFTTrainer(
            model=model,
            processing_class=tokenizer,  # Replaces the deprecated `tokenizer` argument in recent TRL versions
            args=training_args,
            peft_config=lora_config,
            train_dataset=dataset,
        )

        trainer.train()
        ```
      </Step>
    </Steps>

    <Info>
      The dataset in this example uses the conversational format (a `messages` column), so `SFTTrainer` automatically applies Jamba's chat template. For more information about supported dataset formats and advanced SFTTrainer features, see the [TRL documentation](https://huggingface.co/docs/trl/main/en/sft_trainer#dataset-format-support).
    </Info>
  </Tab>

  <Tab title="Jamba Large">
    <Note>
      For Jamba Large LoRA fine-tuning, we recommend the QLoRA + FSDP approach detailed in the QLoRA section below, as it provides better memory efficiency for the larger model.
    </Note>
  </Tab>
</Tabs>
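To use your own data in place of the example dataset, you can convert it to the same conversational format: each example is a dict with a `messages` list of `{"role": ..., "content": ...}` entries. The sketch below assumes hypothetical raw records with `instruction` and `response` fields; those field names and the validation rule (roles limited to system/user/assistant, ending with an assistant turn) are illustrative assumptions, not TRL's own checks.

```python
from typing import Optional

# Hypothetical raw records; replace with your own data source.
raw_records = [
    {"instruction": "Summarize: The cat sat on the mat.", "response": "A cat sat on a mat."},
    {"instruction": "Translate 'hello' to French.", "response": "Bonjour."},
]

ALLOWED_ROLES = {"system", "user", "assistant"}


def to_messages(record: dict, system_prompt: Optional[str] = None) -> dict:
    """Convert one instruction/response pair into a `messages` example."""
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": record["instruction"]})
    messages.append({"role": "assistant", "content": record["response"]})
    return {"messages": messages}


def is_valid_conversation(example: dict) -> bool:
    """Illustrative shape check: known roles, string content, ends with the assistant."""
    messages = example.get("messages")
    if not isinstance(messages, list) or not messages:
        return False
    for message in messages:
        if not isinstance(message, dict):
            return False
        if message.get("role") not in ALLOWED_ROLES or not isinstance(message.get("content"), str):
            return False
    return messages[-1]["role"] == "assistant"


converted = [to_messages(record) for record in raw_records]
print(all(is_valid_conversation(example) for example in converted))  # True
```

You can then build a Hugging Face dataset with `datasets.Dataset.from_list(converted)` and pass it to `SFTTrainer` as `train_dataset`.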

## QLoRA Fine-tuning

[QLoRA](https://arxiv.org/abs/2305.14314) combines LoRA with 4-bit quantization, making it possible to fine-tune Jamba Mini on a single 80GB GPU while maintaining good performance.
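A back-of-the-envelope calculation shows why 4-bit quantization matters. The sketch counts weight memory only, at 2 bytes per parameter for bf16 and 0.5 bytes for 4-bit NF4; it ignores activations, LoRA adapters, optimizer state, and quantization constants. The total parameter counts (about 52B for Jamba Mini and about 398B for Jamba Large) are assumed approximate figures.

```python
# Bytes per parameter for the weights alone (assumption: no activations,
# adapters, optimizer state, or NF4 block-scale constants are counted).
BYTES_PER_PARAM = {"bf16": 2.0, "nf4": 0.5}

# Assumed approximate total parameter counts.
MODELS = {"Jamba Mini": 52e9, "Jamba Large": 398e9}


def weight_memory_gb(num_params: float, fmt: str) -> float:
    """Approximate GB needed to hold the model weights in the given format."""
    return num_params * BYTES_PER_PARAM[fmt] / 1e9


for name, params in MODELS.items():
    for fmt in ("bf16", "nf4"):
        print(f"{name:12s} {fmt}: {weight_memory_gb(params, fmt):4.0f} GB")
```

Under these assumptions, Jamba Mini's bf16 weights alone (about 104GB) exceed one 80GB GPU, while its NF4 weights (about 26GB) leave headroom for activations. For Jamba Large, bf16 weights (about 796GB) exceed the 640GB of an 8x 80GB node, but NF4 weights (about 199GB) fit, which is why the Jamba Large recipe quantizes the model.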

### Prerequisites

Before starting QLoRA fine-tuning, install the required dependencies:

```bash theme={"system"}
pip install trl transformers torch datasets peft bitsandbytes
pip install flash-attn --no-build-isolation  # required for attn_implementation="flash_attention_2"
```

### Implementation

<Tabs>
  <Tab title="Jamba Mini">
    <Steps>
      <Step title="Initialize Tokenizer and Configure Quantization">
        ```python theme={"system"}
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
        from datasets import load_dataset
        from trl import SFTTrainer, SFTConfig
        from peft import LoraConfig

        tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Mini-1.7")

        # Configure 4-bit quantization
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,  # Enable 4-bit quantization
            bnb_4bit_quant_type="nf4",  # Use NormalFloat 4-bit quantization
            bnb_4bit_compute_dtype=torch.bfloat16,  # Compute in bfloat16 for better stability
        )
        ```
      </Step>

      <Step title="Load Model with Quantization">
        ```python theme={"system"}
        model = AutoModelForCausalLM.from_pretrained(
            "ai21labs/AI21-Jamba-Mini-1.7",
            device_map="auto",  # Automatically distribute across available GPUs
            quantization_config=quantization_config,  # Apply 4-bit quantization
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        )
        ```
      </Step>

      <Step title="Configure LoRA Parameters">
        ```python theme={"system"}
        lora_config = LoraConfig(
            r=8,  # Rank of adaptation - controls trainable parameters
            target_modules=[
                "embed_tokens", 
                "x_proj", "in_proj", "out_proj",  # mamba layers
                "gate_proj", "up_proj", "down_proj",  # mlp layers
                "q_proj", "k_proj", "v_proj", "o_proj",  # attention layers
            ],
            task_type="CAUSAL_LM",
            bias="none",
        )
        ```
      </Step>

      <Step title="Prepare Your Dataset">
        ```python theme={"system"}
        # Load dataset (replace with your own dataset)
        dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
        ```
      </Step>

      <Step title="Configure Training Settings">
        ```python theme={"system"}
        training_args = SFTConfig(
            output_dir="./results",  # Where to save the model
            logging_dir="./logs",  # Where to save training logs
            num_train_epochs=2,  # Number of training epochs
            per_device_train_batch_size=8,  # Higher batch size possible with quantization
            learning_rate=1e-5,  # Learning rate for fine-tuning
            logging_steps=1,  # Log training metrics every step
            gradient_checkpointing=True,  # Save memory at cost of compute
            gradient_checkpointing_kwargs={"use_reentrant": False},  # Non-reentrant checkpointing, recommended by PyTorch
            save_steps=100,  # Save model checkpoint every 100 steps
            max_length=4096,  # Maximum sequence length (called max_seq_length in older TRL versions)
        )
        ```
      </Step>

      <Step title="Initialize and Start Training">
        ```python theme={"system"}
        trainer = SFTTrainer(
            model=model,
            processing_class=tokenizer,  # Replaces the deprecated `tokenizer` argument in recent TRL versions
            args=training_args,
            peft_config=lora_config,
            train_dataset=dataset,
        )

        trainer.train()
        ```
      </Step>
    </Steps>
  </Tab>

  <Tab title="Jamba Large">
    <Note>
      Jamba Large fine-tuning requires 8x A100/H100 80GB GPUs and uses QLoRA + FSDP. This approach uses the axolotl framework with a modified version of transformers to reduce memory usage.
    </Note>

    <Info>
      Because of its size, Jamba Large 1.7 must be quantized to train on a single 8-GPU node. Quantization can happen either at the start of the training job or in a separate preprocessing step. To pre-quantize the model, use bitsandbytes and set `bnb_4bit_quant_storage=torch.bfloat16` so that the quantized weights can be sharded with FSDP.
    </Info>

    <Steps>
      <Step title="Install Dependencies">
        ```bash theme={"system"}
        # Install axolotl and dependencies
        git clone https://github.com/axolotl-ai-cloud/axolotl
        cd axolotl
        pip3 install packaging ninja
        pip3 install -e '.[flash-attn,deepspeed]'

        pip install bitsandbytes~=0.43.3
        pip install trl
        pip install peft~=0.12.0
        pip install accelerate~=0.33.0
        pip install mamba-ssm "causal-conv1d>=1.2.0"  # quote the specifier so the shell does not treat ">" as a redirect
        pip install git+https://github.com/xgal/transformers@897f80665c37c531b7803f92655dbc9b3a593fe7
        ```
      </Step>

      <Step title="Pre-quantize Model (Optional)">
        ```python theme={"system"}
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        import torch

        model_name = "ai21labs/AI21-Jamba-Large-1.7"
        output_dir = "AI21-Jamba-Large-1.7-BNB-nf4-bf16"  # Local directory for the quantized checkpoint

        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,  # Also quantize the quantization constants
            bnb_4bit_quant_type="nf4",
            bnb_4bit_quant_storage=torch.bfloat16,  # Store packed 4-bit weights in bf16 tensors so FSDP can shard them
        )

        # Passing quantization_config quantizes the weights while loading
        quantized_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            quantization_config=quantization_config,
        )

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.save_pretrained(output_dir)
        quantized_model.save_pretrained(output_dir)
        ```
      </Step>

      <Step title="Run Training">
        ```bash theme={"system"}
        # Run training with axolotl. If you pre-quantized the model, first set
        # base_model in the YAML config to the directory you saved it to.
        accelerate launch -m axolotl.cli.train examples/jamba/qlora_fsdp.yaml
        ```
      </Step>
    </Steps>

    <Info>
      For detailed configuration files and examples, see the [axolotl Jamba examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/jamba). The modified transformers version keeps CPU RAM usage to the required \~200GB; without it, you would need over 1.6TB of CPU RAM.
    </Info>
  </Tab>
</Tabs>
