Overview
Jamba models support several quantization techniques:
- FP8 Quantization: 8-bit floating point weights for reduced memory footprint and efficient deployment
- ExpertsInt8: Innovative quantization for MoE models in vLLM deployment
- 8-bit Quantization: Using BitsAndBytesConfig for training and inference
FP8 Quantization (vLLM)
These models ship with pre-quantized FP8 weights, significantly reducing storage requirements and memory footprint without compromising output quality.
FP8 quantization requires Hopper architecture GPUs such as NVIDIA H100 and NVIDIA H200.
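If you are unsure whether your GPU meets this requirement, checking its CUDA compute capability is a quick test: Hopper parts such as the H100 and H200 report compute capability 9.0. The following is a minimal sketch using PyTorch (assumed to be installed alongside vLLM):

import torch

# Hopper GPUs (H100, H200) report CUDA compute capability 9.0,
# which is what the FP8 kernels expect.
major, minor = torch.cuda.get_device_capability(0)
if (major, minor) >= (9, 0):
    print("Hopper-class GPU detected; FP8 weights can be used.")
else:
    print(f"Compute capability {major}.{minor} detected; FP8 may not be supported.")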
Pre-quantized Model Weights
Prerequisites
pip install "vllm>=0.6.5,<=0.8.5.post1"
Implementation
Load Pre-quantized FP8 Model
from vllm import LLM, SamplingParams
llm = LLM(
    model="ai21labs/AI21-Jamba-Mini-1.7-FP8",
    max_model_len=100*1024,
)
Generate Text
sampling_params = SamplingParams(
    temperature=0.4,
    top_p=1.0,
    max_tokens=100
)
prompts = ["Explain the advantages of FP8 quantization:"]
outputs = llm.generate(prompts, sampling_params)
print(outputs[0].outputs[0].text)
Pre-quantized FP8 models require no additional quantization parameters since the weights are already quantized.
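If you only have the original BF16 checkpoint, vLLM can also quantize the weights to FP8 at load time through its quantization argument. Whether this online path works for a given model and vLLM version can vary, so treat the following as a sketch rather than the recommended route; the pre-quantized checkpoint above avoids the conversion overhead entirely:

from vllm import LLM

# Online FP8 quantization of the BF16 checkpoint (sketch; startup is slower
# than loading the pre-quantized -FP8 weights shown above)
llm = LLM(
    model="ai21labs/AI21-Jamba-Mini-1.7",  # BF16 checkpoint
    max_model_len=100*1024,
    quantization="fp8",
)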
ExpertsInt8 Quantization (vLLM)
ExpertsInt8 is an innovative and efficient quantization technique developed specifically for Mixture of Experts (MoE) models deployed in vLLM, including Jamba models. This technique enables:
- Jamba Mini 1.7: Deploy on a single 80GB GPU
- Jamba Large 1.7: Deploy on a single node of 8x 80GB GPUs
Prerequisites
pip install "vllm>=0.6.5,<=0.8.5.post1"
Implementation
Load Model with ExpertsInt8
from vllm import LLM
llm = LLM(
    model="ai21labs/AI21-Jamba-Mini-1.7",
    max_model_len=100*1024,
    quantization="experts_int8"  # Enable ExpertsInt8 quantization
)
Generate Text
from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature=0.4,
    top_p=0.95,
    max_tokens=100
)
# Generate text
prompts = ["Explain the benefits of model quantization:"]
outputs = llm.generate(prompts, sampling_params)
print(outputs[0].outputs[0].text)
With ExpertsInt8 quantization, you can fit prompts up to 100K tokens on a single 80GB A100 GPU with Jamba Mini.
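For Jamba Large 1.7 on a single node of 8x 80GB GPUs, the same flag is combined with tensor parallelism. A minimal sketch, assuming eight visible GPUs and that the Large checkpoint follows the same naming pattern as the Mini one:

from vllm import LLM

# ExpertsInt8 plus tensor parallelism across one 8-GPU node
llm = LLM(
    model="ai21labs/AI21-Jamba-Large-1.7",  # assumed repo name
    max_model_len=100*1024,
    tensor_parallel_size=8,  # shard weights across 8 GPUs
    quantization="experts_int8",
)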
8-bit Quantization (Hugging Face)
With 8-bit quantization using BitsAndBytesConfig, it is possible to fit sequences of up to 140K tokens on a single 80GB GPU.
Prerequisites
pip install transformers torch bitsandbytes accelerate
Implementation
Configure 8-bit Quantization
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,  # Enable 8-bit quantization
    llm_int8_skip_modules=["mamba"]  # Exclude Mamba blocks to preserve quality
)
Load Model with Quantization
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/AI21-Jamba-Mini-1.7",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    quantization_config=quantization_config
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-Mini-1.7")
Run Inference
messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "What are the advantages of 8-bit quantization?"}
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors='pt'
).to(model.device)

with torch.no_grad():
    outputs = model.generate(
        input_ids,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
To maintain model quality, we recommend excluding the Mamba blocks from quantization by setting llm_int8_skip_modules=["mamba"].
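If you want to confirm that the skip list took effect, you can inspect which linear layers bitsandbytes actually converted. A minimal sketch, assuming the model object from the loading step above; Linear8bitLt is the 8-bit linear layer class that bitsandbytes swaps in:

import torch
import bitsandbytes as bnb

# Count converted vs. skipped linear layers; module names containing "mamba"
# should remain regular torch.nn.Linear because of llm_int8_skip_modules=["mamba"]
quantized = [n for n, m in model.named_modules() if isinstance(m, bnb.nn.Linear8bitLt)]
skipped = [n for n, m in model.named_modules()
           if "mamba" in n and isinstance(m, torch.nn.Linear)]
print(f"{len(quantized)} linear layers quantized to int8; {len(skipped)} Mamba linears left unquantized")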