# Remi Serra 202407
from ibm_watsonx_ai import APIClient, Credentials
from ibm_watsonx_ai.foundation_models import Model
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames
class wxEngine:
    """Thin wrapper around the ibm_watsonx_ai SDK.

    Provides model listing, token counting and text generation for a
    single watsonx.ai project, caching one Model instance at a time.
    """

    def __init__(self, apiendpoint, apikey, projectid):
        """Build the API client for a watsonx.ai project.

        apiendpoint -- service URL, e.g. https://us-south.ml.cloud.ibm.com
        apikey      -- IBM Cloud API key
        projectid   -- watsonx.ai project id
        """
        # https://ibm.github.io/watsonx-ai-python-sdk/base.html#credentials.Credentials
        self.credentials = Credentials(url=apiendpoint, api_key=apikey)
        self.projectid = projectid
        self.client = APIClient(credentials=self.credentials, project_id=self.projectid)
        # Lazily created by init_model(); holds the most recently requested model.
        self.model = None

    def init_model(self, modelid):
        """Return a Model for `modelid`, reusing the cached instance when possible.

        A new Model is created when none exists yet or when `modelid` differs
        from the cached model's id. Passing None keeps the current model.

        Raises ValueError when `modelid` is None and no model was initialized
        yet (the SDK would otherwise fail with an obscure error on model_id=None).
        """
        if self.model is None and modelid is None:
            raise ValueError("init_model: a model id is required on first use")
        if self.model is None or (
            modelid is not None and modelid != self.model.model_id
        ):
            self.model = Model(
                model_id=modelid,
                params={},
                credentials=self.credentials,
                project_id=self.projectid,
            )
        return self.model

    def list_models(self):
        """Return the ids of the text foundation models available to the client."""
        TextModels_enum = self.client.foundation_models.TextModels
        model_ids = [e.value for e in TextModels_enum]
        return model_ids

    def get_model_max_tokens(self, modelid):
        """Return the maximum sequence length (in tokens) supported by `modelid`."""
        model_specs = self.client.foundation_models.get_model_specs(model_id=modelid)
        max_tokens = model_specs["model_limits"]["max_sequence_length"]
        print(f"get_model_specs:max_tokens:{max_tokens}")
        return max_tokens

    def get_prompt_nb_tokens(self, prompt, modelid):
        """Return the number of tokens `prompt` occupies for model `modelid`."""
        # https://ibm.github.io/watson-machine-learning-sdk/model.html#ibm_watson_machine_learning.foundation_models.Model.tokenize
        self.init_model(modelid)
        prompt_embeddings = self.model.tokenize(prompt, return_tokens=False)
        prompt_len = prompt_embeddings["result"]["token_count"]
        print(f"get_prompt_nb_tokens:prompt_len:{prompt_len}")
        return prompt_len

    def generate_text(
        self,
        prompt: str,
        modelid=None,
        decoding="greedy",
        temperature=1.0,
        seed=42,
        min_new_tokens=50,
        max_new_tokens=200,
        stop_sequences=None,
        repetition_penalty=1,
        stream=False,
    ):
        """Generate text (or a text stream) for `prompt`.

        When `stream` is True, returns the SDK's chunk generator; otherwise
        returns the generated string. `modelid=None` reuses the cached model.
        """
        # None instead of a mutable [] default; normalize here.
        if stop_sequences is None:
            stop_sequences = []
        self.init_model(modelid)
        # Query limits with the *resolved* model id: when modelid is None the
        # cached model is used, and its id must be passed to get_model_specs.
        effective_modelid = self.model.model_id
        # Truncate the input below the model maximum; without this, prompts
        # longer than the model limit make the service return an error.
        truncate_input_tokens = int(self.get_model_max_tokens(effective_modelid) * 0.9)
        wmlparams = {
            GenTextParamsMetaNames.DECODING_METHOD: decoding,
            GenTextParamsMetaNames.TEMPERATURE: temperature,
            GenTextParamsMetaNames.TOP_P: 0.2,
            GenTextParamsMetaNames.TOP_K: 1,
            GenTextParamsMetaNames.RANDOM_SEED: seed,
            GenTextParamsMetaNames.MIN_NEW_TOKENS: min_new_tokens,
            GenTextParamsMetaNames.MAX_NEW_TOKENS: max_new_tokens,
            GenTextParamsMetaNames.STOP_SEQUENCES: stop_sequences,
            GenTextParamsMetaNames.REPETITION_PENALTY: repetition_penalty,
            GenTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: truncate_input_tokens,
        }
        if stream:
            return self.model.generate_text_stream(prompt, wmlparams)
        else:
            return self.model.generate_text(prompt, wmlparams)
## TEST
if __name__ == "__main__":
    engine = wxEngine(
        "https://us-south.ml.cloud.ibm.com", "xxx_apikey_xxxx", "xxxx-projectid-xxxx"
    )
    print(f"Test generate: {engine.generate_text('Who is Einstein ?')}")
    # stream=True returns a generator of text chunks; join them so the test
    # prints the streamed text instead of the generator object's repr.
    chunks = engine.generate_text('Who is Einstein ?', stream=True)
    print(f"Test generate_stream: {''.join(chunks)}")