# svg-editor/watsonx_utils.py
# Remi Serra 202407
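"""Utility wrapper around the ibm_watsonx_ai SDK.

Provides a small wxEngine class used by the svg-editor app to list foundation
models, inspect token limits, tokenize prompts, and generate text.
"""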
from ibm_watsonx_ai import APIClient, Credentials
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
from ibm_watsonx_ai.foundation_models import Model

class wxEngine:
    def __init__(self, apiendpoint, apikey, projectid):
        # https://ibm.github.io/watsonx-ai-python-sdk/base.html#credentials.Credentials
        self.credentials = Credentials(url=apiendpoint, api_key=apikey)
        self.projectid = projectid
        self.client = APIClient(credentials=self.credentials, project_id=self.projectid)
        self.model = None
    def init_model(self, modelid):
        """Create the Model lazily; reuse it unless a different modelid is requested."""
        if self.model is None or (
            modelid is not None and modelid != self.model.model_id
        ):
            if modelid is None:
                # without this guard, Model(model_id=None) fails with an obscure error
                raise ValueError("init_model: a modelid is required for the first call")
            self.model = Model(
                model_id=modelid,
                params={},
                credentials=self.credentials,
                project_id=self.projectid,
            )
        return self.model
    def list_models(self):
        """Return the ids of all text foundation models available to this client."""
        text_models_enum = self.client.foundation_models.TextModels
        model_ids = [e.value for e in text_models_enum]
        # print(f"list_models:model_ids:{model_ids}")
        return model_ids
    def get_model_max_tokens(self, modelid):
        """Return the maximum sequence length (in tokens) supported by the model."""
        model_specs = self.client.foundation_models.get_model_specs(model_id=modelid)
        # print(f"get_model_max_tokens:model_specs:{model_specs}")
        max_tokens = model_specs["model_limits"]["max_sequence_length"]
        print(f"get_model_max_tokens:max_tokens:{max_tokens}")
        return max_tokens
    def get_prompt_nb_tokens(self, prompt, modelid):
        """Count the tokens in a prompt using the model's tokenizer."""
        # https://ibm.github.io/watson-machine-learning-sdk/model.html#ibm_watson_machine_learning.foundation_models.Model.tokenize
        self.init_model(modelid)
        tokenized = self.model.tokenize(prompt, return_tokens=False)
        # print(f"get_prompt_nb_tokens:tokenized:{tokenized}")
        prompt_len = tokenized["result"]["token_count"]
        print(f"get_prompt_nb_tokens:prompt_len:{prompt_len}")
        return prompt_len
    def generate_text(
        self,
        prompt: str,
        modelid=None,
        decoding="greedy",
        temperature=1.0,
        seed=42,
        min_new_tokens=50,
        max_new_tokens=200,
        stop_sequences=None,
        repetition_penalty=1.0,
        stream=False,
    ):
        """Generate text (or a text stream when stream=True) from the prompt."""
        self.init_model(modelid)
        # Set truncate_input_tokens to a value at or below the maximum allowed
        # tokens for the model in use; if the input has more tokens than the
        # model can process and this is not set, an error is generated.
        # Use the resolved model id so modelid=None works once a model exists.
        truncate_input_tokens = int(self.get_model_max_tokens(self.model.model_id) * 0.9)
        wmlparams = {
            GenTextParamsMetaNames.DECODING_METHOD: decoding,
            # TEMPERATURE, TOP_P, TOP_K and RANDOM_SEED only take effect when
            # decoding="sample"; they are ignored under greedy decoding.
            GenTextParamsMetaNames.TEMPERATURE: temperature,
            GenTextParamsMetaNames.TOP_P: 0.2,
            GenTextParamsMetaNames.TOP_K: 1,
            GenTextParamsMetaNames.RANDOM_SEED: seed,
            GenTextParamsMetaNames.MIN_NEW_TOKENS: min_new_tokens,
            GenTextParamsMetaNames.MAX_NEW_TOKENS: max_new_tokens,
            GenTextParamsMetaNames.STOP_SEQUENCES: stop_sequences or [],
            GenTextParamsMetaNames.REPETITION_PENALTY: repetition_penalty,
            GenTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: truncate_input_tokens,
        }
        if stream:
            return self.model.generate_text_stream(prompt, wmlparams)
        return self.model.generate_text(prompt, wmlparams)

## TEST
if __name__ == "__main__":
    engine = wxEngine(
        "https://us-south.ml.cloud.ibm.com", "xxx_apikey_xxxx", "xxxx-projectid-xxxx"
    )
    # a modelid is required on first use; this one is an example, substitute
    # any id returned by list_models()
    modelid = "ibm/granite-13b-instruct-v2"
    print(f"Test generate: {engine.generate_text('Who is Einstein?', modelid=modelid)}")
    # stream=True returns a generator of text chunks, so join them before printing
    stream_chunks = engine.generate_text('Who is Einstein?', modelid=modelid, stream=True)
    print(f"Test generate_stream: {''.join(stream_chunks)}")
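    # A minimal smoke test of the remaining helpers, assuming the same
    # placeholder credentials and example model id as above.
    print(f"Test list_models: {engine.list_models()}")
    print(f"Test get_model_max_tokens: {engine.get_model_max_tokens(modelid)}")
    print(f"Test get_prompt_nb_tokens: {engine.get_prompt_nb_tokens('Who is Einstein?', modelid)}")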