Spaces:
Sleeping
Sleeping
import sys | |
import os | |
prj_root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
sys.path.append(prj_root_path) | |
from code_interpreter.BaseCodeInterpreter import BaseCodeInterpreter | |
from utils.const import * | |
from typing import List, Tuple, Dict | |
import re | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
sys.path.append(os.path.dirname(__file__)) | |
sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
import warnings | |
warnings.filterwarnings("ignore", category=UserWarning, module="transformers") | |
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" | |
class OpenCodeInterpreter(BaseCodeInterpreter): | |
def __init__(
    self,
    model_path: str,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
):
    """Load the tokenizer and causal language model found at ``model_path``.

    Args:
        model_path: Passed unchanged to both ``from_pretrained`` calls.
        load_in_8bit: Forwarded as ``load_in_8bit`` to the model loader.
        load_in_4bit: Forwarded as ``load_in_4bit`` to the model loader.
    """
    # NOTE(review): the base-class initializer is not invoked here --
    # confirm that this is intentional.

    # Plain state that does not depend on the loaded weights.
    # MAX_CODE_OUTPUT_LENGTH is the threshold read by clean_code_output.
    self.dialog = []
    self.MAX_CODE_OUTPUT_LENGTH = 1000

    # build tokenizer
    # trust_remote_code=True permits code shipped with the checkpoint to be
    # executed, so model_path should only point at a trusted source.
    self.tokenizer = AutoTokenizer.from_pretrained(
        model_path, padding_side="right", trust_remote_code=True
    )

    model_options = {
        "device_map": "auto",
        "load_in_4bit": load_in_4bit,
        "load_in_8bit": load_in_8bit,
        "torch_dtype": torch.float16,
        "trust_remote_code": True,
    }
    model = AutoModelForCausalLM.from_pretrained(model_path, **model_options)

    # Size the embedding table to the tokenizer's vocabulary before the
    # model is switched to evaluation mode and stored.
    model.resize_token_embeddings(len(self.tokenizer))
    self.model = model.eval()
def dialog_to_prompt(self, dialog: List[Dict]) -> str:
    """Render ``dialog`` as a single prompt string.

    Args:
        dialog: Message dicts, handed unchanged to the tokenizer's
            ``apply_chat_template``.

    Returns:
        Whatever ``apply_chat_template`` produces with ``tokenize=False``.
    """
    return self.tokenizer.apply_chat_template(dialog, tokenize=False)
def extract_code_blocks(self, prompt: str) -> Tuple[bool, str]:
    """Pull the last fenced Python snippet out of ``prompt``.

    Args:
        prompt: Text that may contain one or more fenced Python blocks.

    Returns:
        ``(True, code)`` with the whitespace-stripped body of the final
        closed fence, or ``(False, "")`` when no closed fence is present.
    """
    fence_open = re.escape("```python")
    fence_close = re.escape("```")
    # DOTALL lets the lazy group span newlines inside the fence.
    fence_re = re.compile(fence_open + r"(.*?)" + fence_close, re.DOTALL)

    snippets = fence_re.findall(prompt)
    if not snippets:
        return False, ""

    # Only the most recent block is of interest.
    return True, snippets[-1].strip()
def clean_code_output(self, output: str) -> str:
    """Shorten over-long execution output to a head and a tail excerpt.

    Output whose length does not exceed ``self.MAX_CODE_OUTPUT_LENGTH`` is
    returned unchanged. Longer output is reduced to its first and last
    ``MAX_CODE_OUTPUT_LENGTH // 5`` characters, joined by a truncation
    marker.

    Args:
        output: Raw text to be shortened if necessary.

    Returns:
        ``output`` itself, or the head + marker + tail excerpt.
    """
    limit = self.MAX_CODE_OUTPUT_LENGTH
    if len(output) <= limit:
        return output

    # One shared excerpt length for both ends. Writing ``-limit // 5``
    # negates first and then floors, which makes the tail one character
    # longer than the head whenever limit is not a multiple of 5.
    keep = max(limit // 5, 0)
    head = output[:keep]
    # An explicit start index stays correct when keep == 0; ``output[-0:]``
    # would return the entire string instead of an empty tail.
    tail = output[len(output) - keep:]
    return head + "\n...(truncated due to length)...\n" + tail