"""Run inference on a Gemma 2 model finetuned for urgency detection.

The module builds the urgency-detection prompt, generates a completion with a
HuggingFace causal language model, and parses the model's answer.
"""

# Standard Library
import ast
import json
from textwrap import dedent
from typing import Any, Optional

# Third Party Library
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase


def _construct_prompt(*, rules_list: list[str]) -> str:
    """Construct the prompt for the finetuned model.

    Parameters
    ----------
    rules_list
        The list of urgency rules to match against the user message.

    Returns
    -------
    str
        The instruction text followed by the rules, numbered from 1.
    """

    # NOTE(review): the line breaks inside the two literals below were
    # reconstructed from a whitespace-collapsed copy of this file. Confirm the
    # exact layout against the prompt the model was finetuned on.
    _prompt_base: str = dedent(
        """
        You are a highly sensitive urgency detector. Score if ANY part of the user
        message corresponds to any part of the urgency rules provided below. Ignore any
        part of the user message that does not correspond to the rules. Respond with
        (a) the rule that is most consistent with the user message, (b) the probability
        between 0 and 1 with increments of 0.1 that ANY part of the user message
        matches the rule, and (c) the reason for the probability.

        Respond in json string:

        {
           best_matching_rule: str
           probability: float
           reason: str
        }
        """
    ).strip()

    _prompt_rules: str = dedent(
        """
        Urgency Rules:
        {urgency_rules}
        """
    ).strip()

    urgency_rules_str = "\n".join(
        f"{i}. {rule}" for i, rule in enumerate(rules_list, 1)
    )
    return (
        _prompt_base + "\n\n" + _prompt_rules.format(urgency_rules=urgency_rules_str)
    )


def _default_text_generation_params(*, eos_token_id: Any) -> dict[str, Any]:
    """Return the default text generation parameters.

    Parameters
    ----------
    eos_token_id
        The end-of-sequence token id, passed through as ``eos_token_id``.

    Returns
    -------
    dict[str, Any]
        A fresh dictionary of keyword arguments for ``model.generate``.
    """

    return {
        "do_sample": True,
        "eos_token_id": eos_token_id,
        "max_new_tokens": 1024,
        "num_return_sequences": 1,
        "repetition_penalty": 1.1,
        "temperature": 1e-6,
        "top_p": 0.9,
    }


def _extract_model_turn(
    *, decoded_output: str, end_marker: str, start_marker: str
) -> str:
    """Return the stripped text between the model-turn markers.

    Parameters
    ----------
    decoded_output
        The decoded text to search.
    end_marker
        The text that terminates the model turn.
    start_marker
        The text that opens the model turn.

    Returns
    -------
    str
        The text between the first ``start_marker`` and the first ``end_marker``
        that follows it, with surrounding whitespace removed.

    Raises
    ------
    ValueError
        If either marker is missing.
    """

    content_start = decoded_output.index(start_marker) + len(start_marker)
    # Look for the end marker only after the model turn has started, so an earlier
    # occurrence cannot produce an empty or inverted slice.
    content_end = decoded_output.index(end_marker, content_start)
    return decoded_output[content_start:content_end].strip()


def _parse_generated_response(response_text: str) -> Any:
    """Parse the model response as JSON, falling back to a Python literal.

    The prompt asks for a JSON string, so JSON is tried first; this accepts
    ``true``/``false``/``null``, which ``ast.literal_eval`` rejects. Responses
    written as Python literals (e.g. single-quoted keys) still parse through the
    fallback.

    Parameters
    ----------
    response_text
        The text produced by the model for its turn.

    Returns
    -------
    Any
        The parsed value.

    Raises
    ------
    SyntaxError, ValueError, TypeError, RecursionError
        If the text is neither valid JSON nor a valid Python literal.
    """

    try:
        return json.loads(response_text)
    except json.JSONDecodeError:
        return ast.literal_eval(response_text)


def get_completions(
    *,
    model: Any,
    rules_list: list[str],
    skip_special_tokens_during_decode: bool = False,
    text_generation_params: Optional[dict[str, Any]] = None,
    tokenizer: PreTrainedTokenizerBase,
    user_message: str,
) -> dict[str, Any]:
    """Get completions from the model for the given data.

    Parameters
    ----------
    model
        The model for inference.
    rules_list
        The list of urgency rules to match against the user message.
    skip_special_tokens_during_decode
        Specifies whether to skip special tokens during the decoding process. The
        model turn is located via special-token markers, so when they are skipped
        the markers are not expected to be found and the raw output is returned.
    text_generation_params
        Dictionary containing text generation parameters for the LLM model. If not
        specified (or empty), then default values will be used.
    tokenizer
        The tokenizer for the model. Its ``additional_special_tokens`` must hold
        exactly two entries: the start-of-turn and end-of-turn tokens, in that
        order.
    user_message
        The user message to match against the urgency rules.

    Returns
    -------
    dict[str, Any]
        A dictionary with the keys "user_message" and "generated_json". If the
        model output cannot be parsed, then the original decoded output is
        returned in the "generated_json" key.

    Raises
    ------
    AssertionError
        If any rule in `rules_list` is empty.
    """

    # Raised explicitly (instead of `assert`) so the check survives `python -O`
    # while callers still see the same exception type.
    if not all(rules_list):
        raise AssertionError("Rules must be non-empty strings!")

    text_generation_params = text_generation_params or _default_text_generation_params(
        eos_token_id=tokenizer.eos_token_id
    )

    # NOTE: the tokenizer is deliberately left unmodified here. Assigning to
    # `tokenizer.add_special_tokens` would replace the tokenizer's method of that
    # name on the caller's object; `add_special_tokens=False` is passed explicitly
    # to `encode` below because the chat template is used to build the prompt.
    start_of_turn, end_of_turn = tokenizer.additional_special_tokens
    start_of_turn_model = f"{start_of_turn}model"
    end_of_turn_model = f"{end_of_turn}{tokenizer.eos_token}"

    input_ = (
        _construct_prompt(rules_list=rules_list)
        + f"\n\nUser Message:\n{user_message}"
    )
    chat = [{"role": "user", "content": input_}]
    prompt = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, tokenize=False
    )
    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    outputs = model.generate(
        input_ids=inputs.to(model.device), **text_generation_params
    )
    decoded_output = tokenizer.decode(
        outputs[0], skip_special_tokens=skip_special_tokens_during_decode
    )

    completion_dict = {"user_message": user_message, "generated_json": decoded_output}
    try:
        generated_response = _extract_model_turn(
            decoded_output=decoded_output,
            end_marker=end_of_turn_model,
            start_marker=start_of_turn_model,
        )
        completion_dict["generated_json"] = _parse_generated_response(
            generated_response
        )
    except (SyntaxError, ValueError, TypeError, RecursionError):
        # Deliberate best-effort: a missing marker or an unparsable response leaves
        # the raw decoded output in "generated_json", as documented above.
        pass

    return completion_dict


if __name__ == "__main__":
    DTYPE = torch.bfloat16
    MODEL_ID = "idinsight/gemma-2-2b-it-ud"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, add_eos_token=False)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="auto", return_dict=True, torch_dtype=DTYPE
    )
    text_generation_params = _default_text_generation_params(
        eos_token_id=tokenizer.eos_token_id
    )
    response = get_completions(
        model=model,
        rules_list=[
            "NOT URGENT",
            "Bleeding from the vagina",
            "Bad tummy pain",
            "Bad headache that won’t go away",
            "Bad headache that won’t go away",
            "Changes to vision",
            "Trouble breathing",
            "Hot or very cold, and very weak",
            "Fits or uncontrolled shaking",
            "Baby moves less",
            "Fluid from the vagina",
            "Feeding problems",
            "Fits or uncontrolled shaking",
            "Fast, slow or difficult breathing",
            "Too hot or cold",
            "Baby’s colour changes",
            "Vomiting and watery poo",
            "Infected belly button",
            "Swollen or infected eyes",
            "Bulging or sunken soft spot",
        ],
        skip_special_tokens_during_decode=False,
        text_generation_params=text_generation_params,
        tokenizer=tokenizer,
        user_message="If my newborn can't able to breathe what can i do",
    )
    print(f"{response = }")