File size: 2,012 Bytes
3943768 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import base64
from enums import unknown_prompt_type, template_prompt_type
def get_use_chat_template(tokenizer, prompt_type=None):
    """Decide whether the tokenizer's own chat template should be used.

    Returns False when there is no tokenizer, or when ``prompt_type`` is
    anything other than None, the empty string, ``unknown_prompt_type`` or
    ``template_prompt_type``. Otherwise the answer is whatever
    ``has_chat_template(tokenizer)`` reports.
    """
    if tokenizer is None:
        return False

    # Prompt types for which the tokenizer's template is eligible.
    eligible_prompt_types = [None, '', unknown_prompt_type, template_prompt_type]
    if prompt_type not in eligible_prompt_types:
        return False

    return has_chat_template(tokenizer)
def has_chat_template(tokenizer):
    """Report whether the tokenizer carries a non-empty chat template.

    Checks the ``chat_template`` attribute first and ``default_chat_template``
    second. An attribute counts only if it exists and is neither None nor
    the empty string. Always returns a bool.
    """
    for attr_name in ('chat_template', 'default_chat_template'):
        # hasattr first: the attribute may be absent on this tokenizer.
        if hasattr(tokenizer, attr_name) and getattr(tokenizer, attr_name) not in [None, '']:
            return True
    return False
def get_chat_template(tokenizer):
    """Return the tokenizer's chat template, or None if it has none.

    ``chat_template`` takes precedence over ``default_chat_template``. An
    attribute is skipped when it is missing, None, or the empty string.
    A None tokenizer yields None.
    """
    if tokenizer is None:
        return None

    for attr_name in ('chat_template', 'default_chat_template'):
        if not hasattr(tokenizer, attr_name):
            continue
        candidate = getattr(tokenizer, attr_name)
        if candidate not in [None, '']:
            return candidate
    return None
def base64_encode_jinja_template(template_str):
    """Encode a template string as base64 text.

    The string is encoded to UTF-8 bytes, base64-encoded, and the result is
    returned as a ``str``.
    """
    return base64.b64encode(template_str.encode('utf-8')).decode('utf-8')
def base64_decode_jinja_template(encoded_str):
    """Decode a base64-encoded template, passing plain strings through.

    If ``is_base64`` accepts ``encoded_str``, it is base64-decoded and the
    bytes are returned as UTF-8 text. Otherwise the input is returned
    unchanged.
    """
    if not is_base64(encoded_str):
        # just normal string, pass along
        return encoded_str

    raw_bytes = base64.b64decode(encoded_str.encode('utf-8'))
    return raw_bytes.decode('utf-8')
def is_base64(s):
    """Return True if ``s`` is valid base64 that decodes to UTF-8 text.

    The check is a heuristic: ``s`` must have a length that is a multiple
    of 4, contain only base64-alphabet characters (strict validation), and
    decode to bytes that form valid UTF-8.

    Input that cannot be base64 at all (for example None or an int) yields
    False instead of raising, so callers can use this as a plain predicate.
    """
    try:
        # Check if the length is a multiple of 4. Kept inside the try so an
        # unsized input (None, int, ...) is reported as "not base64".
        if len(s) % 4 != 0:
            return False
        # validate=True makes b64decode reject non-alphabet characters
        # rather than silently discarding them.
        decoded = base64.b64decode(s, validate=True)
        # Check if the decoded bytes can be converted to a UTF-8 string
        decoded.decode('utf-8')
    except (TypeError, ValueError):
        # ValueError covers binascii.Error (bad base64), UnicodeDecodeError
        # (not UTF-8) and non-ASCII str input; TypeError covers input that
        # is not a string or bytes-like object.
        return False
    return True
|