Python SDK - with Amazon SageMaker

AI21 Studio Python SDK with Amazon SageMaker Guide

This guide covers how to use the AI21 Studio Python SDK with SageMaker integration to interact with Jurassic-2 models.

📘

The Python SDK for SageMaker currently works only with Jurassic-2 models. We will be upgrading to Jamba soon.

Setup

To get started with AI21's SDK with SageMaker integration, you'll need to install it with the AWS extra first. To do so, simply run the following command:

$ pip install -U "ai21[AWS]>=2.1.1"

Using AI21 Studio Python SDK with SageMaker

To use the AI21 Studio Python SDK with SageMaker:

  1. Create an Amazon SageMaker instance and configure it.
  2. Install the AI21 Studio Python SDK with SageMaker integration using the command mentioned above.
  3. After setting up the environment, you can follow the same steps as described in the AI21 Studio Python SDK Guide for Jurassic-2 models.

Examples

Completion: Using AI21 Studio Python SDK with SageMaker for Jurassic-2

Below is a sample usage of the AI21 Python SDK with SageMaker integration to interact with Jurassic-2 models:

from ai21 import AI21SageMakerClient

client = AI21SageMakerClient(endpoint_name="<your_endpoint_name>")

# J2 Mid
response_mid = client.completion.create(
  prompt="These are a few of my favorite",
  num_results=1,
  max_tokens=2,
  temperature=0.7,
)

print(response_mid)

By customizing the request parameters, you can control the content and style of the generated text. For a full list of available options, check out our Complete API page.

All Jurassic-2 models, including j2-light, j2-mid, and j2-ultra, are accessed through the same client.completion.create() call; the model is determined by the SageMaker endpoint the client points at.

Response

Here's an example of a response object from executing a j2-mid model:

{
   "id":"94078cb6-687e-4262-ef8f-1d7c2b0dbd2b",
   "prompt":{
      "text":"These are a few of my favorite",
      "tokens":[
         {
            "generatedToken":{
               "token":"▁These▁are",
               "logprob":-8.824776649475098,
               "raw_logprob":-8.824776649475098
            },
            "topTokens":null,
            "textRange":{
               "start":0,
               "end":9
            }
         },
         {
            "generatedToken":{
               "token":"▁a▁few",
               "logprob":-4.798709869384766,
               "raw_logprob":-4.798709869384766
            },
            "topTokens":null,
            "textRange":{
               "start":9,
               "end":15
            }
         },
         {
            "generatedToken":{
               "token":"▁of▁my▁favorite",
               "logprob":-1.0864331722259521,
               "raw_logprob":-1.0864331722259521
            },
            "topTokens":null,
            "textRange":{
               "start":15,
               "end":30
            }
         }
      ]
   },
   "completions":[
      {
         "data":{
            "text":" things –",
            "tokens":[
               {
                  "generatedToken":{
                     "token":"▁things",
                     "logprob":-0.0003219324571546167,
                     "raw_logprob":-0.47372230887413025
                  },
                  "topTokens":null,
                  "textRange":{
                     "start":0,
                     "end":7
                  }
               },
               {
                  "generatedToken":{
                     "token":"▁–",
                     "logprob":-7.797079563140869,
                     "raw_logprob":-4.319167613983154
                  },
                  "topTokens":null,
                  "textRange":{
                     "start":7,
                     "end":9
                  }
               }
            ]
         },
         "finishReason":{
            "reason":"length",
            "length":2
         }
      }
   ]
}
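
Once parsed (for example, with the standard json module), the structure above can be navigated as plain dictionaries. A minimal sketch, using a trimmed-down version of the sample response shown above:

```python
import json

# A trimmed version of the sample response above (most token fields omitted).
raw = """
{
  "id": "94078cb6-687e-4262-ef8f-1d7c2b0dbd2b",
  "prompt": {"text": "These are a few of my favorite"},
  "completions": [
    {
      "data": {
        "text": " things \\u2013",
        "tokens": [
          {"generatedToken": {"token": "\\u2581things", "logprob": -0.0003219324571546167}},
          {"generatedToken": {"token": "\\u2581\\u2013", "logprob": -7.797079563140869}}
        ]
      },
      "finishReason": {"reason": "length", "length": 2}
    }
  ]
}
"""

response = json.loads(raw)

# Top-level fields: id, prompt, and completions.
completion = response["completions"][0]
print(response["id"])
print(response["prompt"]["text"] + completion["data"]["text"])
print(completion["finishReason"]["reason"])
```

Concatenating the prompt text with the completion text reconstructs the full generated passage.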

The response is a nested data structure containing information about the processed request, prompt, and completions. At the top level, the response has the following fields:

ID

A unique string ID for the processed request. Repeated identical requests receive different IDs.

prompt

The prompt includes the raw text, the tokens with their log probabilities, and the top-K alternative tokens at each position, if requested. It has two nested fields:

  • text (string)
  • tokens (list of TokenData)

completions

A list of completions, including raw text, tokens, and log probabilities. The number of completions corresponds to the requested numResults. Each completion has two nested fields:

  • data, which contains the text (string) and tokens (list of TokenData) for the completion.
  • finishReason, a nested data structure describing the reason generation was terminated for this completion.
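
Because generation stops either when the token limit is reached or when the model ends on its own, finishReason is the place to check for truncated output. A small illustrative helper (the dict literal is a hand-written stand-in for SDK output):

```python
def is_truncated(completion: dict) -> bool:
    # A reason of "length" means generation stopped because the max_tokens
    # limit was hit; any other reason means the model finished on its own
    # (for example, by emitting an end-of-text or stop sequence).
    return completion["finishReason"]["reason"] == "length"

# Stand-in completion mirroring the sample response above.
sample = {"finishReason": {"reason": "length", "length": 2}}
print(is_truncated(sample))  # True: the 2-token completion was cut off
```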

TokenData

The TokenData object provides detailed information about each token in both the prompt and the completions. It includes the following fields:

generatedToken:

The generatedToken field consists of two nested fields:

  • token: The string representation of the token.
  • logprob: The log probability (float) of the token after the sampling parameters have been applied.
  • raw_logprob: The raw log probability (float) of the token, before sampling parameters are applied. With neutral sampling settings (temperature=1, topP=1), raw_logprob equals logprob.
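
Since these values are natural-log probabilities, exponentiating them recovers the model's probability for the token. For instance, for the "▁things" token in the sample response above:

```python
import math

logprob = -0.0003219324571546167    # after sampling parameters are applied
raw_logprob = -0.47372230887413025  # the model's raw value

print(math.exp(logprob))      # ≈ 0.9997: near-certain after sampling adjustments
print(math.exp(raw_logprob))  # ≈ 0.623: the model's raw probability
```

The gap between the two shows how a temperature below 1 sharpens the distribution toward the most likely token.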

topTokens

The topTokens field is a list of the top K alternative tokens for this position, sorted by probability, according to the topKReturn request parameter. If topKReturn is set to 0, this field will be null.

Each token in the list includes:

  • token: The string representation of the alternative token.
  • logprob: The predicted log probability of the alternative token as a float value.

textRange

The textRange field indicates the start and end offsets of the token in the decoded text string:

  • start: The starting index of the token in the decoded text string.
  • end: The ending index of the token in the decoded text string.
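
These offsets index into the decoded text, so slicing with them recovers each token's surface form. A short sketch using the completion tokens from the sample response above:

```python
# Decoded completion text and its tokens, copied from the sample response.
text = " things –"
tokens = [
    {"generatedToken": {"token": "▁things"}, "textRange": {"start": 0, "end": 7}},
    {"generatedToken": {"token": "▁–"}, "textRange": {"start": 7, "end": 9}},
]

for tok in tokens:
    r = tok["textRange"]
    # The ▁ marker in the token string corresponds to a leading space
    # in the decoded text, which is why the ranges include it.
    print(repr(text[r["start"]:r["end"]]))
```

Concatenating the slices in order reproduces the decoded text exactly, since the ranges tile it without gaps or overlaps.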

List Model Package Versions

Parameters

  • model_name (str): Name of the model. Available options: 'j2-light', 'j2-mid', 'j2-ultra', 'gec', 'contextual-answers', 'paraphrase', 'summarize'.
  • region (str): The AWS region in which the model is deployed.

import ai21

model_versions = ai21.SageMaker.list_model_package_versions(
  model_name='j2-mid',
  region='us-east-1',
)

Response

(list of str): All the available versions for this model in the specified region.

Get Model Package ARN

Parameters

  • model_name (str): Name of the model. Available options: 'j2-light', 'j2-mid', 'j2-ultra', 'gec', 'contextual-answers', 'paraphrase', 'summarize'.
  • region (str): The AWS region in which the model is deployed.
  • version (str, optional): The version of the model package (default is "latest"). If you're not sure which versions are available, call list_model_package_versions.

import ai21

model_package_arn = ai21.SageMaker.get_model_package_arn(
  model_name='j2-mid',
  region='us-east-1',
  version='2-2-000',
)

Response

(str): The generated Amazon Resource Name (ARN) for the specified model package.


Additional Resources

For a detailed example of using Jurassic-2 Mid on SageMaker through Model Packages, you can refer to this notebook on AI21 Labs' SageMaker GitHub repository.

By using the AI21 Studio Python SDK with SageMaker integration, you can use AI21 Studio's Jurassic-2 models directly in your SageMaker environment, streamlining the development and deployment of your machine learning solutions. Note that the SageMaker version of the SDK supports the Jurassic-2 models but not the Task-Specific Models. Additionally, no AI21 Studio API key is needed when using the SDK with SageMaker.