Logit Bias

🐶 What is it?

Would you like to exclude certain words or filter specific tokens from your generated text?

Logit biases can be used to promote or suppress the generation of specific tokens. This is accomplished by adding a bias term to the logit of each token you specify: a positive bias increases the probability that the token is generated, and a negative bias decreases it.
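To make the mechanism concrete, here is a minimal sketch of how adding a bias to one logit shifts the resulting probabilities. It is an illustration only: the logit values are made up, the helper names softmax and apply_logit_bias are hypothetical, and the model's actual sampling pipeline may include steps not shown here.

```python
import math

def softmax(logits):
    """Convert a {token: logit} mapping into a {token: probability} mapping."""
    peak = max(logits.values())  # subtract the max for numerical stability
    exps = {tok: math.exp(val - peak) for tok, val in logits.items()}
    total = sum(exps.values())
    return {tok: e / total for tok, e in exps.items()}

def apply_logit_bias(logits, bias):
    """Add each bias to the matching token's logit; other tokens are unchanged."""
    return {tok: val + bias.get(tok, 0.0) for tok, val in logits.items()}

# Made-up logits for three candidate tokens (illustrative values, not model output).
logits = {"\u2581a\u2581box\u2581of": 3.0, "\u2581riding": 2.0, "\u2581a": 1.0}

before = softmax(logits)
after = softmax(apply_logit_bias(logits, {"\u2581a\u2581box\u2581of": -100.0}))

print(max(before, key=before.get))  # most likely token without the bias
print(max(after, key=after.get))    # most likely token with the bias applied
```

With a bias as large as -100, the biased token's probability becomes negligible, so the remaining tokens share essentially all of the probability mass.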

Note that logit bias operates at the token level, so you must refer to valid tokens in the Jurassic-2 vocabulary; otherwise, the API returns an error. Watch out for whitespace, which is replaced with a special underscore character (▁) in our string representation of tokens (see here).
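As a small aid for the whitespace rule, the sketch below rewrites ordinary text into that string representation. It assumes the special underscore is U+2581, as it appears in the example on this page; the helper name to_token_string is hypothetical. It only rewrites characters and does not check whether the result is a valid token in the vocabulary.

```python
SPACE_MARKER = "\u2581"  # assumed: the special underscore used in token strings

def to_token_string(text):
    """Replace each plain space with the special underscore character."""
    return text.replace(" ", SPACE_MARKER)

# The leading space matters: it marks the start of a new word.
print(to_token_string(" a box of"))
```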

✍️ How to do it?

Pass the logitBias parameter (written logit_bias in the Python SDK call below). It is a dictionary that maps strings to floats: each key is the text representation of a token, and each value is the bias applied to that token.
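Because an invalid token makes the request fail, it can help to sanity-check the dictionary locally before sending it. The following is a rough sketch of such a check, not part of the SDK: check_logit_bias is a hypothetical helper, and it only catches structural mistakes (wrong types, plain spaces left in a token). It cannot tell whether a token actually exists in the vocabulary.

```python
import math

def check_logit_bias(logit_bias):
    """Return a list of human-readable problems found in a logit-bias dict."""
    problems = []
    for token, bias in logit_bias.items():
        if not isinstance(token, str) or not token:
            problems.append(f"{token!r}: token must be a non-empty string")
            continue
        if " " in token:
            # Plain spaces should have been replaced by the special underscore.
            problems.append(f"{token!r}: contains a plain space; use '\u2581' instead")
        is_number = isinstance(bias, (int, float)) and not isinstance(bias, bool)
        if not is_number or not math.isfinite(bias):
            problems.append(f"{token!r}: bias must be a finite number")
    return problems

# A well-formed entry produces no problems; a plain-space token is flagged.
print(check_logit_bias({"\u2581a\u2581box\u2581of": -100.0}))
print(check_logit_bias({" a box of": -100.0}))
```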

Consider the following call:

import os

from ai21 import AI21Client

client = AI21Client(
    # This is the default and can be omitted
    api_key=os.environ.get("AI21_API_KEY"),
)

res = client.completion.create(
  model="j2-ultra",
  prompt="Life is like",
  temperature=0,
  max_tokens=3,
)

print(res.completions[0].data.text)

The response will be:

 a box of chocolates

Now, let's introduce a large negative bias to avoid generating the expected continuation "a box of":

res = client.completion.create(
  model="j2-ultra",
  prompt="Life is like",
  temperature=0,
  max_tokens=3,
  logit_bias={"▁a▁box▁of": -100.0},
)

print(res.completions[0].data.text)

The response will be:

 riding a bicycle