File size: 3,250 Bytes
b585c7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import os
import typing
import json
from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from pydantic.v1 import root_validator
from src.utils import FakeTokenizer
class ChatContentHandler(LLMContentHandler):
    """Content handler for chat-style SageMaker endpoints.

    Wraps the prompt in a single conversation (an optional system message
    followed by the user message) and extracts the generated message content
    from the endpoint's JSON response.
    """

    content_type = "application/json"
    accepts = "application/json"
    # System message prepended to every request. Override it on a subclass or
    # an instance; an empty string sends only the user message.
    system_prompt: str = "You are a helpful assistant."

    def transform_input(self, prompt: str, model_kwargs: typing.Dict) -> bytes:
        """Serialize ``prompt`` and ``model_kwargs`` into a UTF-8 JSON request body.

        The payload shape is ``{"inputs": [messages], "parameters": model_kwargs}``,
        i.e. ``inputs`` is a list holding exactly one conversation.
        """
        messages: typing.List[typing.Dict[str, str]] = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({'role': 'user', 'content': prompt})
        input_dict = {'inputs': [messages], "parameters": model_kwargs}
        return json.dumps(input_dict).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Decode the response and return the first generation's message content.

        ``output`` is consumed with ``.read()``, so a readable stream (such as
        the endpoint response body) is expected despite the ``bytes`` annotation.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generation"]['content']
class BaseContentHandler(LLMContentHandler):
    """Content handler for plain text-completion SageMaker endpoints.

    Sends the raw prompt string as ``inputs`` together with the generation
    parameters, and returns the ``generation`` string of the first result.
    """

    accepts = "application/json"
    content_type = "application/json"

    def transform_input(self, prompt: str, model_kwargs: typing.Dict) -> bytes:
        """Encode the prompt and generation parameters as a UTF-8 JSON body."""
        payload = json.dumps({'inputs': prompt, "parameters": model_kwargs})
        return payload.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Read the response stream and return the first result's ``generation`` text."""
        raw_body = output.read()
        first_result = json.loads(raw_body.decode("utf-8"))[0]
        return first_result["generation"]
class H2OSagemakerEndpoint(SagemakerEndpoint):
    """``SagemakerEndpoint`` that accepts optional static AWS keys and a tokenizer.

    The boto3 ``sagemaker-runtime`` client is built in ``validate_environment``.
    ``get_token_ids`` uses ``tokenizer`` when set, otherwise ``FakeTokenizer``.
    """

    # Optional static AWS credentials. An empty string means "not provided":
    # the client then uses the boto3 session's credentials instead (the
    # ``credentials_profile_name`` profile if set, else the default chain).
    aws_access_key_id: str = ""
    aws_secret_access_key: str = ""
    # Object with an ``encode(text)`` method returning token ids; None selects
    # the FakeTokenizer fallback in ``get_token_ids``.
    tokenizer: typing.Any = None

    @root_validator()
    def validate_environment(cls, values: typing.Dict) -> typing.Dict:
        """Validate that boto3 is installed and AWS credentials are usable.

        Builds a ``sagemaker-runtime`` client and stores it in ``values["client"]``.

        Raises:
            ImportError: if the boto3 package cannot be imported.
            ValueError: if the boto3 session or client cannot be created.
        """
        try:
            import boto3
        except ImportError as exc:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from exc

        try:
            profile_name = values.get("credentials_profile_name")
            if profile_name is not None:
                session = boto3.Session(profile_name=profile_name)
            else:
                # use default credentials
                session = boto3.Session()
            # botocore treats any non-None key pair (even "") as explicit
            # credentials, which would bypass the profile/default chain, and
            # validates a non-None region literally. Map empty values to None
            # so the session's configuration is used as the fallback.
            values["client"] = session.client(
                "sagemaker-runtime",
                region_name=values.get("region_name") or None,
                aws_access_key_id=values.get("aws_access_key_id") or None,
                aws_secret_access_key=values.get("aws_secret_access_key") or None,
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e
        return values

    def get_token_ids(self, text: str) -> typing.List[int]:
        """Return the token ids for ``text``.

        Uses ``self.tokenizer.encode`` when a tokenizer is configured, otherwise
        falls back to ``FakeTokenizer``.
        """
        tokenizer = self.tokenizer
        if tokenizer is not None:
            return tokenizer.encode(text)
        # FakeTokenizer.encode is indexed by 'input_ids' here, so it presumably
        # returns a mapping rather than a plain list — confirm in src.utils.
        return FakeTokenizer().encode(text)['input_ids']
|