'
+ for i, _row in enumerate(history[::-1]):
+ row = [convert_to_markdown(entry) for entry in _row]
+
+ output += f"""
+
+ """
+
+ if len(row[0]) == 0: # don't display empty user messages
+ continue
+
+ output += f"""
+
+ """
+
+ output += "
"
+
+ return output
+
+
+def generate_cai_chat_html(history, name1, name2, reset_cache=False):
+ output = f''
+
+ # We use ?name2 and ?time.time() to force the browser to reset caches
+ img_bot = f'<img src="file/cache/pfp_character.png?{name2}">' if Path("cache/pfp_character.png").exists() else ''
+ img_me = f'<img src="file/cache/pfp_me.png?{time.time() if reset_cache else ""}">' if Path("cache/pfp_me.png").exists() else ''
+
+ for i, _row in enumerate(history[::-1]):
+ row = [convert_to_markdown(entry) for entry in _row]
+
+ output += f"""
+
+
+ {img_bot}
+
+
+
+ {name2}
+
+
+ {row[1]}
+
+
+
+ """
+
+ if len(row[0]) == 0: # don't display empty user messages
+ continue
+
+ output += f"""
+
+
+ {img_me}
+
+
+
+ {name1}
+
+
+ {row[0]}
+
+
+
+ """
+
+ output += "
"
+ return output
+
+
+def generate_chat_html(history, name1, name2, reset_cache=False):
+ output = f''
+
+ for i, _row in enumerate(history[::-1]):
+ row = [convert_to_markdown(entry) for entry in _row]
+
+ output += f"""
+
+ """
+
+ if len(row[0]) == 0: # don't display empty user messages
+ continue
+
+ output += f"""
+
+ """
+
+ output += "
"
+ return output
+
+
+def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
+ if mode == "cai-chat":
+ return generate_cai_chat_html(history, name1, name2, reset_cache)
+ elif mode == "chat":
+ return generate_chat_html(history, name1, name2)
+ elif mode == "instruct":
+ return generate_instruct_html(history)
+ else:
+ return ''
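+
+# Minimal usage sketch (assuming history is a list of [user_text, bot_text]
+# pairs, oldest first, as built elsewhere in the web UI):
+#
+# html = chat_html_wrapper([['Hi', 'Hello! How can I help?']], 'You', 'Assistant', 'cai-chat')
+#
+# The generators above walk the history in reverse, so the most recent exchange
+# is emitted first in the returned markup.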
diff --git a/text-generation-webui-main/modules/llama_attn_hijack.py b/text-generation-webui-main/modules/llama_attn_hijack.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5c5c92e40c115cb5dee77026605aab7b7711511
--- /dev/null
+++ b/text-generation-webui-main/modules/llama_attn_hijack.py
@@ -0,0 +1,176 @@
+import math
+import sys
+import torch
+import torch.nn as nn
+import transformers.models.llama.modeling_llama
+
+from typing import Optional
+from typing import Tuple
+
+import modules.shared as shared
+
+
+if shared.args.xformers:
+ try:
+ import xformers.ops
+ except Exception:
+ print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
+
+
+def hijack_llama_attention():
+ if shared.args.xformers:
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
+ print("Replaced attention with xformers_attention")
+ elif shared.args.sdp_attention:
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
+ print("Replaced attention with sdp_attention")
+
+
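+# The two replacement forwards below mirror LlamaAttention.forward from
+# transformers: project to Q/K/V, apply rotary position embeddings, concatenate
+# the cached past key/values, then compute attention. Only the final attention
+# product differs, delegating to xformers or torch scaled_dot_product_attention
+# whenever the full attention matrix (output_attentions=True) is not requested.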
+def xformers_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ #We only apply xformers optimizations if we don't need to output the whole attention matrix
+ if not output_attentions:
+ dtype = query_states.dtype
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ #This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+ #We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+ if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
+ attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+ else:
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
+ attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
+ attn_weights = None
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights, past_key_value
+
+
+def sdp_attention_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ #We only apply sdp attention if we don't need to output the whole attention matrix
+ if not output_attentions:
+ attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
+ attn_weights = None
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights, past_key_value
diff --git a/text-generation-webui-main/modules/llamacpp_model.py b/text-generation-webui-main/modules/llamacpp_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9461db109c0c172b97b11da075a9adcf30a12254
--- /dev/null
+++ b/text-generation-webui-main/modules/llamacpp_model.py
@@ -0,0 +1,82 @@
+import multiprocessing
+
+import llamacpp
+
+from modules import shared
+from modules.callbacks import Iteratorize
+
+
+class LlamaCppTokenizer:
+ """A thin wrapper over the llamacpp tokenizer"""
+ def __init__(self, model: llamacpp.LlamaInference):
+ self._tokenizer = model.get_tokenizer()
+ self.eos_token_id = 2
+ self.bos_token_id = 0
+
+ @classmethod
+ def from_model(cls, model: llamacpp.LlamaInference):
+ return cls(model)
+
+ def encode(self, prompt: str):
+ return self._tokenizer.tokenize(prompt)
+
+ def decode(self, ids):
+ return self._tokenizer.detokenize(ids)
+
+
+class LlamaCppModel:
+ def __init__(self):
+ self.initialized = False
+
+ @classmethod
+ def from_pretrained(cls, path):
+ params = llamacpp.InferenceParams()
+ params.path_model = str(path)
+ params.n_threads = shared.args.threads or multiprocessing.cpu_count() // 2
+
+ _model = llamacpp.LlamaInference(params)
+
+ result = cls()
+ result.model = _model
+ result.params = params
+
+ tokenizer = LlamaCppTokenizer.from_model(_model)
+ return result, tokenizer
+
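+ # The generation loop below first feeds any pending prompt tokens into the
+ # llamacpp context (ingest_all_pending_input), then evaluates and samples one
+ # token at a time until token_count tokens are produced or EOS is sampled.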
+ def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
+ params = self.params
+ params.n_predict = token_count
+ params.top_p = top_p
+ params.top_k = top_k
+ params.temp = temperature
+ params.repeat_penalty = repetition_penalty
+ # params.repeat_last_n = repeat_last_n
+
+ # self.model.params = params
+ self.model.add_bos()
+ self.model.update_input(context)
+
+ output = ""
+ is_end_of_text = False
+ ctr = 0
+ while ctr < token_count and not is_end_of_text:
+ if self.model.has_unconsumed_input():
+ self.model.ingest_all_pending_input()
+ else:
+ self.model.eval()
+ token = self.model.sample()
+ text = self.model.token_to_str(token)
+ output += text
+ is_end_of_text = token == self.model.token_eos()
+ if callback:
+ callback(text)
+ ctr += 1
+
+ return output
+
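+ # Iteratorize (from modules.callbacks) wraps the callback-based generate()
+ # above so that the streamed tokens can be consumed as a plain generator,
+ # which is what the web UI uses for incremental output.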
+ def generate_with_streaming(self, **kwargs):
+ with Iteratorize(self.generate, kwargs, callback=None) as generator:
+ reply = ''
+ for token in generator:
+ reply += token
+ yield reply
diff --git a/text-generation-webui-main/modules/llamacpp_model_alternative.py b/text-generation-webui-main/modules/llamacpp_model_alternative.py
new file mode 100644
index 0000000000000000000000000000000000000000..2671f2273cce6c47348405cd6ea188eb075ad9fb
--- /dev/null
+++ b/text-generation-webui-main/modules/llamacpp_model_alternative.py
@@ -0,0 +1,65 @@
+'''
+Based on
+https://github.com/abetlen/llama-cpp-python
+
+Documentation:
+https://abetlen.github.io/llama-cpp-python/
+'''
+
+from llama_cpp import Llama, LlamaCache
+
+from modules import shared
+from modules.callbacks import Iteratorize
+
+
+class LlamaCppModel:
+ def __init__(self):
+ self.initialized = False
+
+ @classmethod
+ def from_pretrained(cls, path):
+ result = cls()
+
+ params = {
+ 'model_path': str(path),
+ 'n_ctx': 2048,
+ 'seed': 0,
+ 'n_threads': shared.args.threads or None,
+ 'n_batch': shared.args.n_batch
+ }
+ result.model = Llama(**params)
+ result.model.set_cache(LlamaCache())
+
+ # This is ugly, but the model and the tokenizer are the same object in this library.
+ return result, result
+
+ def encode(self, string):
+ if type(string) is str:
+ string = string.encode()
+ return self.model.tokenize(string)
+
+ def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
+ if type(context) is str:
+ context = context.encode()
+ tokens = self.model.tokenize(context)
+
+ output = b""
+ count = 0
+ for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty):
+ text = self.model.detokenize([token])
+ output += text
+ if callback:
+ callback(text.decode())
+
+ count += 1
+ if count >= token_count or (token == self.model.token_eos()):
+ break
+
+ return output.decode()
+
+ def generate_with_streaming(self, **kwargs):
+ with Iteratorize(self.generate, kwargs, callback=None) as generator:
+ reply = ''
+ for token in generator:
+ reply += token
+ yield reply
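+
+# Minimal usage sketch (the model path below is hypothetical):
+#
+# model, tokenizer = LlamaCppModel.from_pretrained('models/ggml-model-q4_0.bin')
+# for partial in model.generate_with_streaming(context='Write a haiku.', token_count=32):
+# print(partial)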
diff --git a/text-generation-webui-main/modules/models.py b/text-generation-webui-main/modules/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..a17fba4b345047f5dea26b8527bd49b9267ae464
--- /dev/null
+++ b/text-generation-webui-main/modules/models.py
@@ -0,0 +1,288 @@
+import gc
+import json
+import os
+import re
+import time
+import zipfile
+from pathlib import Path
+
+import numpy as np
+import torch
+import transformers
+from accelerate import infer_auto_device_map, init_empty_weights
+from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
+ AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer)
+
+import modules.shared as shared
+from modules import llama_attn_hijack
+
+transformers.logging.set_verbosity_error()
+
+if shared.args.flexgen:
+ from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
+
+local_rank = None
+if shared.args.deepspeed:
+ import deepspeed
+ from transformers.deepspeed import (HfDeepSpeedConfig,
+ is_deepspeed_zero3_enabled)
+
+ from modules.deepspeed_parameters import generate_ds_config
+
+ # Distributed setup
+ local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
+ torch.cuda.set_device(local_rank)
+ deepspeed.init_distributed()
+ ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
+ dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
+
+
+def find_model_type(model_name):
+ model_name = model_name.lower()
+ if 'rwkv-' in model_name:
+ return 'rwkv'
+ elif len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))) > 0:
+ return 'llamacpp'
+ elif re.match(r'.*ggml.*\.bin', model_name):
+ return 'llamacpp'
+ elif 'chatglm' in model_name:
+ return 'chatglm'
+ elif 'galactica' in model_name:
+ return 'galactica'
+ elif 'llava' in model_name:
+ return 'llava'
+ elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
+ return 'gpt4chan'
+ else:
+ return 'HF_generic'
+
+
+def load_model(model_name):
+ print(f"Loading {model_name}...")
+ t0 = time.time()
+
+ shared.model_type = find_model_type(model_name)
+ if shared.model_type == 'chatglm':
+ LoaderClass = AutoModel
+ trust_remote_code = shared.args.trust_remote_code
+ else:
+ LoaderClass = AutoModelForCausalLM
+ trust_remote_code = False
+
+ # Load the model in simple 16-bit mode by default
+ if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.model_type in ['rwkv', 'llamacpp']]):
+ model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=trust_remote_code)
+ if torch.has_mps:
+ device = torch.device('mps')
+ model = model.to(device)
+ else:
+ model = model.cuda()
+
+ # FlexGen
+ elif shared.args.flexgen:
+ # Initialize environment
+ env = ExecutionEnv.create(shared.args.disk_cache_dir)
+
+ # Offloading policy
+ policy = Policy(1, 1,
+ shared.args.percent[0], shared.args.percent[1],
+ shared.args.percent[2], shared.args.percent[3],
+ shared.args.percent[4], shared.args.percent[5],
+ overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
+ cpu_cache_compute=False, attn_sparsity=1.0,
+ compress_weight=shared.args.compress_weight,
+ comp_weight_config=CompressionConfig(
+ num_bits=4, group_size=64,
+ group_dim=0, symmetric=False),
+ compress_cache=False,
+ comp_cache_config=CompressionConfig(
+ num_bits=4, group_size=64,
+ group_dim=2, symmetric=False))
+
+ model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)
+
+ # DeepSpeed ZeRO-3
+ elif shared.args.deepspeed:
+ model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
+ model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
+ model.module.eval() # Inference
+ print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
+
+ # RWKV model (not on HuggingFace)
+ elif shared.model_type == 'rwkv':
+ from modules.RWKV import RWKVModel, RWKVTokenizer
+
+ model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
+ tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
+
+ return model, tokenizer
+
+ # llamacpp model
+ elif shared.model_type == 'llamacpp':
+ from modules.llamacpp_model_alternative import LlamaCppModel
+
+ path = Path(f'{shared.args.model_dir}/{model_name}')
+ if path.is_file():
+ model_file = path
+ else:
+ model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
+
+ print(f"llama.cpp weights detected: {model_file}\n")
+ model, tokenizer = LlamaCppModel.from_pretrained(model_file)
+ return model, tokenizer
+
+ # Quantized model
+ elif shared.args.wbits > 0:
+
+ # Monkey patch
+ if shared.args.monkey_patch:
+ print("Warning: applying the monkey patch for using LoRAs in 4-bit mode.\nIt may cause undefined behavior outside its intended scope.")
+ from modules.monkey_patch_gptq_lora import load_model_llama
+
+ model, tokenizer = load_model_llama(model_name)
+ return model, tokenizer
+
+ # No monkey patch
+ else:
+ from modules.GPTQ_loader import load_quantized
+
+ model = load_quantized(model_name)
+
+ # Custom
+ else:
+ params = {"low_cpu_mem_usage": True}
+ if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
+ print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
+ shared.args.cpu = True
+
+ if shared.args.cpu:
+ params["torch_dtype"] = torch.float32
+ else:
+ params["device_map"] = 'auto'
+ params["trust_remote_code"] = trust_remote_code
+ if shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
+ params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
+ elif shared.args.load_in_8bit:
+ params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
+ elif shared.args.bf16:
+ params["torch_dtype"] = torch.bfloat16
+ else:
+ params["torch_dtype"] = torch.float16
+
+ if shared.args.gpu_memory:
+ memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
+ max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
+ max_memory = {}
+ for i in range(len(memory_map)):
+ max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
+ max_memory['cpu'] = max_cpu_memory
+ params['max_memory'] = max_memory
+ elif shared.args.auto_devices:
+ total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
+ suggestion = round((total_mem - 1000) / 1000) * 1000
+ if total_mem - suggestion < 800:
+ suggestion -= 1000
+ suggestion = int(round(suggestion / 1000))
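+ # Worked example: on a 24 GiB card, total_mem is about 24576 MiB, so the first
+ # suggestion is round((24576 - 1000) / 1000) * 1000 = 24000; since
+ # 24576 - 24000 = 576 < 800, it is lowered by 1000, giving --gpu-memory 23.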
+ print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
+
+ max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
+ params['max_memory'] = max_memory
+
+ if shared.args.disk:
+ params["offload_folder"] = shared.args.disk_cache_dir
+
+ checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
+
+ if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
+ config = AutoConfig.from_pretrained(checkpoint)
+ with init_empty_weights():
+ model = LoaderClass.from_config(config)
+ model.tie_weights()
+ params['device_map'] = infer_auto_device_map(
+ model,
+ dtype=torch.int8,
+ max_memory=params['max_memory'],
+ no_split_module_classes=model._no_split_modules
+ )
+
+ model = LoaderClass.from_pretrained(checkpoint, **params)
+
+ # Hijack attention with xformers
+ if any((shared.args.xformers, shared.args.sdp_attention)):
+ llama_attn_hijack.hijack_llama_attention()
+
+ # Loading the tokenizer
+ if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
+ tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
+ elif type(model) is transformers.LlamaForCausalLM:
+ tokenizer = None
+
+ # Try to load a universal LLaMA tokenizer
+ if shared.model_type != 'llava':
+ for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
+ if p.exists():
+ print(f"Loading the universal LLaMA tokenizer from {p}...")
+ tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
+ break
+
+ # Otherwise, load it from the model folder and hope that these
+ # are not outdated tokenizer files.
+ if tokenizer is None:
+ tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
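+ # Early LLaMA conversions often shipped tokenizer configs with wrong special
+ # token ids, so the known values (bos=1, eos=2, pad=0) are forced below.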
+ try:
+ tokenizer.eos_token_id = 2
+ tokenizer.bos_token_id = 1
+ tokenizer.pad_token_id = 0
+ except:
+ pass
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=trust_remote_code)
+
+ print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
+ return model, tokenizer
+
+
+def clear_torch_cache():
+ gc.collect()
+ if not shared.args.cpu:
+ torch.cuda.empty_cache()
+
+
+def unload_model():
+ shared.model = shared.tokenizer = None
+ clear_torch_cache()
+
+
+def reload_model():
+ unload_model()
+ shared.model, shared.tokenizer = load_model(shared.model_name)
+
+
+def load_soft_prompt(name):
+ if name == 'None':
+ shared.soft_prompt = False
+ shared.soft_prompt_tensor = None
+ else:
+ with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
+ zf.extract('tensor.npy')
+ zf.extract('meta.json')
+ j = json.loads(open('meta.json', 'r').read())
+ print(f"\nLoading the softprompt \"{name}\".")
+ for field in j:
+ if field != 'name':
+ if type(j[field]) is list:
+ print(f"{field}: {', '.join(j[field])}")
+ else:
+ print(f"{field}: {j[field]}")
+ print()
+ tensor = np.load('tensor.npy')
+ Path('tensor.npy').unlink()
+ Path('meta.json').unlink()
+ tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
+ tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
+
+ shared.soft_prompt = True
+ shared.soft_prompt_tensor = tensor
+
+ return name
diff --git a/text-generation-webui-main/modules/monkey_patch_gptq_lora.py b/text-generation-webui-main/modules/monkey_patch_gptq_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e591b52486f445373b202fc03b62c36c7597ea9
--- /dev/null
+++ b/text-generation-webui-main/modules/monkey_patch_gptq_lora.py
@@ -0,0 +1,41 @@
+# Copied from https://github.com/johnsmith0031/alpaca_lora_4bit
+
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path("repositories/alpaca_lora_4bit")))
+
+import autograd_4bit
+from autograd_4bit import (Autograd4bitQuantLinear,
+ load_llama_model_4bit_low_ram)
+from monkeypatch.peft_tuners_lora_monkey_patch import (
+ Linear4bitLt, replace_peft_model_with_gptq_lora_model)
+
+from modules import shared
+from modules.GPTQ_loader import find_quantized_model_file
+
+replace_peft_model_with_gptq_lora_model()
+
+def load_model_llama(model_name):
+
+ config_path = str(Path(f'{shared.args.model_dir}/{model_name}'))
+ model_path = str(find_quantized_model_file(model_name))
+ model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=shared.args.groupsize, is_v1_model=False)
+
+ for n, m in model.named_modules():
+ if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt):
+ if m.is_v1_model:
+ m.zeros = m.zeros.half()
+ m.scales = m.scales.half()
+ m.bias = m.bias.half()
+ autograd_4bit.use_new = True
+ autograd_4bit.auto_switch = True
+
+ try:
+ tokenizer.eos_token_id = 2
+ tokenizer.bos_token_id = 1
+ tokenizer.pad_token_id = 0
+ except:
+ pass
+
+ return model, tokenizer
diff --git a/text-generation-webui-main/modules/shared.py b/text-generation-webui-main/modules/shared.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a24f2206410ab7a6d1545e39f13f03f0f11867c
--- /dev/null
+++ b/text-generation-webui-main/modules/shared.py
@@ -0,0 +1,210 @@
+import argparse
+from pathlib import Path
+
+import yaml
+
+model = None
+tokenizer = None
+model_name = "None"
+model_type = None
+lora_names = []
+soft_prompt_tensor = None
+soft_prompt = False
+
+# Chat variables
+history = {'internal': [], 'visible': []}
+character = 'None'
+stop_everything = False
+processing_message = '*Is typing...*'
+
+# UI elements (buttons, sliders, HTML, etc)
+gradio = {}
+
+# For keeping the values of UI elements on page reload
+persistent_interface_state = {}
+
+# Generation input parameters
+input_params = []
+
+# For restarting the interface
+need_restart = False
+
+settings = {
+ 'max_new_tokens': 200,
+ 'max_new_tokens_min': 1,
+ 'max_new_tokens_max': 2000,
+ 'seed': -1,
+ 'name1': 'You',
+ 'name2': 'Assistant',
+ 'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
+ 'greeting': '',
+ 'end_of_turn': '',
+ 'custom_stopping_strings': '',
+ 'stop_at_newline': False,
+ 'add_bos_token': True,
+ 'ban_eos_token': False,
+ 'skip_special_tokens': True,
+ 'truncation_length': 2048,
+ 'truncation_length_min': 0,
+ 'truncation_length_max': 8192,
+ 'mode': 'cai-chat',
+ 'instruction_template': 'None',
+ 'chat_prompt_size': 2048,
+ 'chat_prompt_size_min': 0,
+ 'chat_prompt_size_max': 2048,
+ 'chat_generation_attempts': 1,
+ 'chat_generation_attempts_min': 1,
+ 'chat_generation_attempts_max': 5,
+ 'default_extensions': [],
+ 'chat_default_extensions': ["gallery"],
+ 'presets': {
+ 'default': 'Default',
+ '.*(alpaca|llama|llava)': "LLaMA-Precise",
+ '.*pygmalion': 'NovelAI-Storywriter',
+ '.*RWKV': 'Naive',
+ },
+ 'prompts': {
+ 'default': 'QA',
+ '.*(gpt4chan|gpt-4chan|4chan)': 'GPT-4chan',
+ '.*oasst': 'Open Assistant',
+ '.*alpaca': "Alpaca",
+ },
+ 'lora_prompts': {
+ 'default': 'QA',
+ '.*alpaca': "Alpaca",
+ }
+}
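+# Apart from 'default', the keys of the 'presets', 'prompts' and 'lora_prompts'
+# maps above are regular expressions that the UI matches against the model or
+# LoRA name to pick a starting preset/prompt.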
+
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
+
+# Basic settings
+parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
+parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
+parser.add_argument('--cai-chat', action='store_true', help='DEPRECATED: use --chat instead.')
+parser.add_argument('--model', type=str, help='Name of the model to load by default.')
+parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
+parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
+parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
+parser.add_argument('--model-menu', action='store_true', help='Show a model menu in the terminal when the web UI is first launched.')
+parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
+parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
+parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
+parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
+
+# Accelerate/transformers
+parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
+parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
+parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maximum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
+parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
+parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
+parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directory to save the disk cache to. Defaults to "cache".')
+parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
+parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
+parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
+parser.add_argument('--xformers', action='store_true', help="Use xformers' memory efficient attention. This should increase your tokens/s.")
+parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
+parser.add_argument('--trust-remote-code', action='store_true', help="Set trust_remote_code=True while loading a model. Necessary for ChatGLM.")
+
+# llama.cpp
+parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')
+parser.add_argument('--n_batch', type=int, default=8, help='Processing batch size for llama.cpp.')
+
+# GPTQ
+parser.add_argument('--wbits', type=int, default=0, help='Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
+parser.add_argument('--model_type', type=str, help='Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
+parser.add_argument('--groupsize', type=int, default=-1, help='Group size.')
+parser.add_argument('--pre_layer', type=int, default=0, help='The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models.')
+parser.add_argument('--monkey-patch', action='store_true', help='Apply the monkey patch for using LoRAs with quantized models.')
+parser.add_argument('--quant_attn', action='store_true', help='(triton) Enable quant attention.')
+parser.add_argument('--warmup_autotune', action='store_true', help='(triton) Enable warmup autotune.')
+parser.add_argument('--fused_mlp', action='store_true', help='(triton) Enable fused mlp.')
+
+# FlexGen
+parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')
+parser.add_argument('--percent', type=int, nargs="+", default=[0, 100, 100, 0, 100, 0], help='FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0).')
+parser.add_argument("--compress-weight", action="store_true", help="FlexGen: activate weight compression.")
+parser.add_argument("--pin-weight", type=str2bool, nargs="?", const=True, default=True, help="FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%%).")
+
+# DeepSpeed
+parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
+parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
+parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
+
+# RWKV
+parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".')
+parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.')
+
+# Gradio
+parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
+parser.add_argument('--listen-host', type=str, help='The hostname that the server will use.')
+parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
+parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
+parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
+parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None)
+
+# API
+parser.add_argument('--api', action='store_true', help='Enable the API extension.')
+parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudflare.')
+
+
+args = parser.parse_args()
+args_defaults = parser.parse_args([])
+
+# Deprecation warnings for parameters that have been renamed
+deprecated_dict = {}
+for k in deprecated_dict:
+ if getattr(args, k) != deprecated_dict[k][1]:
+ print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.\n")
+ setattr(args, deprecated_dict[k][0], getattr(args, k))
+
+# Deprecation warnings for parameters that have been removed
+if args.cai_chat:
+ print("Warning: --cai-chat is deprecated. Use --chat instead.\n")
+ args.chat = True
+
+# Security warnings
+if args.trust_remote_code:
+ print("Warning: trust_remote_code is enabled. This is dangerous.\n")
+if args.share:
+ print("Warning: the gradio \"share link\" feature downloads a proprietary and\nunaudited blob to create a reverse tunnel. This is potentially dangerous.\n")
+
+# Activating the API extension
+if args.api or args.public_api:
+ if args.extensions is None:
+ args.extensions = ['api']
+ elif 'api' not in args.extensions:
+ args.extensions.append('api')
+
+
+def is_chat():
+ return args.chat
+
+
+# Loading model-specific settings (default)
+p = Path(f'{args.model_dir}/config.yaml')
+if p.exists():
+ with open(p, 'r') as f:
+ model_config = yaml.safe_load(f)
+else:
+ model_config = {}
+
+# Applying user-defined model settings
+p = Path(f'{args.model_dir}/config-user.yaml')
+if p.exists():
+ with open(p, 'r') as f:
+ user_config = yaml.safe_load(f)
+ for k in user_config:
+ if k in model_config:
+ model_config[k].update(user_config[k])
+ else:
+ model_config[k] = user_config[k]
diff --git a/text-generation-webui-main/modules/text_generation.py b/text-generation-webui-main/modules/text_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..032fc84c2ca703546d0d1246bd1bdab9f7583ba4
--- /dev/null
+++ b/text-generation-webui-main/modules/text_generation.py
@@ -0,0 +1,316 @@
+import ast
+import random
+import re
+import time
+import traceback
+
+import numpy as np
+import torch
+import transformers
+
+import modules.shared as shared
+from modules.callbacks import (Iteratorize, Stream,
+ _SentinelTokenStoppingCriteria)
+from modules.extensions import apply_extensions
+from modules.html_generator import generate_4chan_html, generate_basic_html
+from modules.models import clear_torch_cache, local_rank
+
+
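+# Example: with truncation_length=2048 and max_new_tokens=200, up to 1848 prompt
+# tokens are kept, minus the soft prompt length when one is loaded.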
+def get_max_prompt_length(state):
+ max_length = state['truncation_length'] - state['max_new_tokens']
+ if shared.soft_prompt:
+ max_length -= shared.soft_prompt_tensor.shape[1]
+ return max_length
+
+
+def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
+ if shared.model_type in ['rwkv', 'llamacpp']:
+ input_ids = shared.tokenizer.encode(str(prompt))
+ input_ids = np.array(input_ids).reshape(1, len(input_ids))
+ return input_ids
+ else:
+ input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
+
+ # This is a hack for making replies more creative.
+ if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
+ input_ids = input_ids[:, 1:]
+
+ # Llama adds this extra token when the first character is '\n', and this
+ # compromises the stopping criteria, so we just remove it
+ if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
+ input_ids = input_ids[:, 1:]
+
+ # Handling truncation
+ if truncation_length is not None:
+ input_ids = input_ids[:, -truncation_length:]
+
+ if shared.model_type in ['rwkv', 'llamacpp'] or shared.args.cpu:
+ return input_ids
+ elif shared.args.flexgen:
+ return input_ids.numpy()
+ elif shared.args.deepspeed:
+ return input_ids.to(device=local_rank)
+ elif torch.has_mps:
+ device = torch.device('mps')
+ return input_ids.to(device)
+ else:
+ return input_ids.cuda()
+
+
+def decode(output_ids, skip_special_tokens=True):
+ if skip_special_tokens:
+ reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
+ reply = reply.replace(r'<|endoftext|>', '')
+ return reply
+ else:
+ return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
+
+
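+# The soft prompt is applied by prepending the learned embedding tensor to the
+# token embeddings; filler_input_ids is a dummy all-zeros sequence of matching
+# length so that model.generate() still receives token ids of a consistent shape.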
+def generate_softprompt_input_tensors(input_ids):
+ inputs_embeds = shared.model.transformer.wte(input_ids)
+ inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
+ filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
+ # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
+ return inputs_embeds, filler_input_ids
+
+
+# Removes empty replies from gpt4chan outputs
+def fix_gpt4chan(s):
+ for i in range(10):
+ s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
+ s = re.sub("--- [0-9]*\n *\n---", "---", s)
+ s = re.sub("--- [0-9]*\n\n\n---", "---", s)
+ return s
+
+
+# Fix the LaTeX equations in galactica
+def fix_galactica(s):
+ s = s.replace(r'\[', r'$')
+ s = s.replace(r'\]', r'$')
+ s = s.replace(r'\(', r'$')
+ s = s.replace(r'\)', r'$')
+ s = s.replace(r'$$', r'$')
+ s = re.sub(r'\n', r'\n\n', s)
+ s = re.sub(r"\n{3,}", "\n\n", s)
+ return s
+
+
+def formatted_outputs(reply, model_name):
+ if not shared.is_chat():
+ if shared.model_type == 'galactica':
+ reply = fix_galactica(reply)
+ return reply, reply, generate_basic_html(reply)
+ elif shared.model_type == 'gpt4chan':
+ reply = fix_gpt4chan(reply)
+ return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
+ else:
+ return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
+ else:
+ return reply
+
+
+def set_manual_seed(seed):
+ seed = int(seed)
+ if seed == -1:
+ seed = random.randint(1, 2**31)
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+ return seed
+
+
+def stop_everything_event():
+ shared.stop_everything = True
+
+
+def generate_reply(question, state, eos_token=None, stopping_strings=[]):
+
+ if shared.model_name == 'None' or shared.model is None:
+ print("No model is loaded! Select one in the Model tab.")
+ yield formatted_outputs(question, shared.model_name)
+ return
+
+ clear_torch_cache()
+ seed = set_manual_seed(state['seed'])
+ shared.stop_everything = False
+ generate_params = {}
+ t0 = time.time()
+
+ original_question = question
+ if not shared.is_chat():
+ question = apply_extensions('input', question)
+
+ # These models are not loaded through Hugging Face transformers, so we
+ # handle them separately and return from this function early
+ if shared.model_type in ['rwkv', 'llamacpp']:
+
+ if shared.args.verbose:
+ print(f'\n\n{question}\n--------------------\n')
+
+ for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
+ generate_params[k] = state[k]
+ generate_params['token_count'] = state['max_new_tokens']
+ try:
+ if shared.args.no_stream:
+ reply = shared.model.generate(context=question, **generate_params)
+ output = original_question + reply
+ if not shared.is_chat():
+ reply = original_question + apply_extensions('output', reply)
+ yield formatted_outputs(reply, shared.model_name)
+ else:
+ if not shared.is_chat():
+ yield formatted_outputs(question, shared.model_name)
+
+ # RWKV has proper streaming, which is very nice.
+ # No need to generate 8 tokens at a time.
+ for reply in shared.model.generate_with_streaming(context=question, **generate_params):
+ output = original_question + reply
+ if not shared.is_chat():
+ reply = original_question + apply_extensions('output', reply)
+ yield formatted_outputs(reply, shared.model_name)
+
+ except Exception:
+ traceback.print_exc()
+ finally:
+ t1 = time.time()
+ original_tokens = len(encode(original_question)[0])
+ new_tokens = len(encode(output)[0]) - original_tokens
+ print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
+ return
+
+ input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
+ output = input_ids[0]
+
+ if shared.args.verbose:
+ print(f'\n\n{decode(input_ids[0], state["skip_special_tokens"])}\n--------------------\n')
+
+ cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
+ eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
+ if eos_token is not None:
+ eos_token_ids.append(int(encode(eos_token)[0][-1]))
+
+ # Handling the stopping strings
+ stopping_criteria_list = transformers.StoppingCriteriaList()
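+ # state['custom_stopping_strings'] is a comma-separated list of quoted strings,
+ # so wrapping it in brackets and calling ast.literal_eval() turns it into a
+ # Python list; the loop below uses the first non-empty of the caller-provided
+ # stopping_strings and that list.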
+ for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
+ if type(st) is list and len(st) > 0:
+ sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
+ stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
+ break
+
+ if not shared.args.flexgen:
+ for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
+ generate_params[k] = state[k]
+ generate_params['eos_token_id'] = eos_token_ids
+ generate_params['stopping_criteria'] = stopping_criteria_list
+ if state['ban_eos_token']:
+ generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
+ else:
+ for k in ['max_new_tokens', 'do_sample', 'temperature']:
+ generate_params[k] = state[k]
+ generate_params['stop'] = eos_token_ids[-1]
+ if not shared.args.no_stream:
+ generate_params['max_new_tokens'] = 8
+
+ if shared.args.no_cache:
+ generate_params.update({'use_cache': False})
+ if shared.args.deepspeed:
+ generate_params.update({'synced_gpus': True})
+ if shared.soft_prompt:
+ inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
+ question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
+ original_input_ids = input_ids
+ generate_params.update({'inputs_embeds': inputs_embeds})
+ generate_params.update({'inputs': filler_input_ids})
+ else:
+ question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
+ original_input_ids = input_ids
+ generate_params.update({'inputs': input_ids})
+ if inputs_embeds is not None:
+ generate_params.update({'inputs_embeds': inputs_embeds})
+
+ try:
+ # Generate the entire reply at once.
+ if shared.args.no_stream:
+ with torch.no_grad():
+ output = shared.model.generate(**generate_params)[0]
+ if cuda:
+ output = output.cuda()
+
+ if shared.soft_prompt:
+ output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+
+ new_tokens = len(output) - len(input_ids[0])
+ reply = decode(output[-new_tokens:], state['skip_special_tokens'])
+ if not shared.is_chat():
+ reply = original_question + apply_extensions('output', reply)
+
+ yield formatted_outputs(reply, shared.model_name)
+
+ # Stream the reply 1 token at a time.
+ # This is based on the trick of using 'stopping_criteria' to create an iterator.
+ elif not shared.args.flexgen:
+
+ def generate_with_callback(callback=None, **kwargs):
+ kwargs['stopping_criteria'].append(Stream(callback_func=callback))
+ clear_torch_cache()
+ with torch.no_grad():
+ shared.model.generate(**kwargs)
+
+ def generate_with_streaming(**kwargs):
+ return Iteratorize(generate_with_callback, kwargs, callback=None)
+
+ if not shared.is_chat():
+ yield formatted_outputs(original_question, shared.model_name)
+
+ with generate_with_streaming(**generate_params) as generator:
+ for output in generator:
+ if shared.soft_prompt:
+ output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+
+ new_tokens = len(output) - len(input_ids[0])
+ reply = decode(output[-new_tokens:], state['skip_special_tokens'])
+ if not shared.is_chat():
+ reply = original_question + apply_extensions('output', reply)
+
+ if output[-1] in eos_token_ids:
+ break
+
+ yield formatted_outputs(reply, shared.model_name)
+
+ # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
+ else:
+ for i in range(state['max_new_tokens'] // 8 + 1):
+ clear_torch_cache()
+ with torch.no_grad():
+ output = shared.model.generate(**generate_params)[0]
+
+ if shared.soft_prompt:
+ output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+
+ new_tokens = len(output) - len(original_input_ids[0])
+ reply = decode(output[-new_tokens:], state['skip_special_tokens'])
+ if not shared.is_chat():
+ reply = original_question + apply_extensions('output', reply)
+
+ if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
+ break
+
+ yield formatted_outputs(reply, shared.model_name)
+ input_ids = np.reshape(output, (1, output.shape[0]))
+ if shared.soft_prompt:
+ inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
+ generate_params.update({'inputs_embeds': inputs_embeds})
+ generate_params.update({'inputs': filler_input_ids})
+ else:
+ generate_params.update({'inputs': input_ids})
+
+ yield formatted_outputs(reply, shared.model_name)
+
+ except Exception:
+ traceback.print_exc()
+ finally:
+ t1 = time.time()
+ original_tokens = len(original_input_ids[0])
+ new_tokens = len(output) - original_tokens
+ print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
+ return
diff --git a/text-generation-webui-main/modules/training.py b/text-generation-webui-main/modules/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..cde4a555e57fb6a80fd9c4b5edc8c106a1e8b64c
--- /dev/null
+++ b/text-generation-webui-main/modules/training.py
@@ -0,0 +1,499 @@
+import json
+import math
+import sys
+import threading
+import time
+import traceback
+from pathlib import Path
+
+import gradio as gr
+import torch
+import transformers
+from datasets import Dataset, load_dataset
+from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training,
+ set_peft_model_state_dict)
+
+from modules import shared, ui
+from modules.evaluate import calculate_perplexity, generate_markdown_table, save_past_evaluations
+from server import get_available_loras, get_available_models
+
+# This mapping is from a very recent commit, not yet released.
+# If not available, default to a backup map for some common model types.
+try:
+ from peft.utils.other import \
+ TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
+ model_to_lora_modules
+except ImportError:
+ standard_modules = ["q_proj", "v_proj"]
+ model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"]}
+
+WANT_INTERRUPT = False
+
+PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "do_shuffle", "higher_rank_limit", "warmup_steps", "optimizer"]
+
+# Mapping of Python class names to peft IDs
+MODEL_CLASSES = {
+ "LlamaForCausalLM": "llama",
+ "OPTForCausalLM": "opt",
+ "GPTJForCausalLM": "gptj",
+ "GPTNeoXForCausalLM": "gpt_neox"
+}
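+# Maps type(model).__name__ to the short key that peft's target-module table
+# (model_to_lora_modules above) is indexed by when setting up the LoRA config.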
+
+
+def get_datasets(path: str, ext: str):
+ return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
+
+
+def create_train_interface():
+ with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
+ gr.Markdown("Confused? [[Click here for a guide]](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md)")
+
+ with gr.Row():
+ lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
+ always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name given is the same as an existing file, checking this will replace that file. Leaving unchecked will load that file and continue from it (must use the same rank value as the original had).')
+ save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
+
+ with gr.Row():
+ copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=get_available_loras())
+ ui.create_refresh_button(copy_from, lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
+
+ with gr.Row():
+ # TODO: Implement multi-device support.
+ micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
+ batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
+
+ with gr.Row():
+ epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
+ learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
+ lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.')
+
+ # TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale.
+ lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, higher values like 128 or 256 are good for teaching content upgrades, extremely high values (1024+) are difficult to train but may improve fine-detail learning for large datasets. Higher ranks also require higher VRAM.')
+ lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
+
+ cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
+
+ with gr.Tab(label='Formatted Dataset'):
+ with gr.Row():
+ dataset = gr.Dropdown(choices=get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
+ ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button')
+ eval_dataset = gr.Dropdown(choices=get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
+ ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'json')}, 'refresh-button')
+ format = gr.Dropdown(choices=get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
+ ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_datasets('training/formats', 'json')}, 'refresh-button')
+
+ eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
+
+ with gr.Tab(label="Raw text file"):
+ with gr.Row():
+ raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
+ ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button')
+
+ with gr.Row():
+ overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - i.e., how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
+ newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
+
+ with gr.Accordion(label='Advanced Options', open=False):
+ lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
+ warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')
+ optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.')
+
+ with gr.Row():
+ do_shuffle = gr.Checkbox(label='Shuffle Dataset', value=True, info='If checked, the dataset will be randomly shuffled. This can help reduce overfitting.')
+ higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
+
+ with gr.Row():
+ start_button = gr.Button("Start LoRA Training")
+ stop_button = gr.Button("Interrupt")
+
+ output = gr.Markdown(value="Ready")
+
+ with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'):
+ with gr.Row():
+ with gr.Column():
+ models = gr.Dropdown(get_available_models(), label='Models', multiselect=True)
+ evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.')
+ with gr.Row():
+ stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
+ max_length = gr.Slider(label='max_length', minimum=0, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
+
+ with gr.Row():
+ start_current_evaluation = gr.Button("Evaluate loaded model")
+ start_evaluation = gr.Button("Evaluate selected models")
+ stop_evaluation = gr.Button("Interrupt")
+
+ with gr.Column():
+ evaluation_log = gr.Markdown(value='')
+
+ evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
+ save_comments = gr.Button('Save comments')
+
+ # Training events
+ all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit, warmup_steps, optimizer]
+ copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
+ start_button.click(do_train, all_params, output)
+ stop_button.click(do_interrupt, None, None, queue=False)
+ higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
+
+ # Evaluation events. For some reason, the interrupt event
+ # doesn't work with the .then() syntax, so I write them one
+ # by one in this ugly but functional way.
+ ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
+ start_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
+
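+ # A hidden State is used to pass the literal 'current model' to calculate_perplexity when evaluating the loaded model.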
+ tmp = gr.State('')
+ start_current_evaluation.click(lambda: ['current model'], None, tmp)
+ ev_cur = start_current_evaluation.click(calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
+ start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
+
+ stop_evaluation.click(None, None, None, cancels=[ev, ev_cur], queue=False)
+ save_comments.click(
+ save_past_evaluations, evaluation_table, None).then(
+ lambda: "Comments saved.", None, evaluation_log, show_progress=False)
+
+
+def do_interrupt():
+ global WANT_INTERRUPT
+ WANT_INTERRUPT = True
+
+
+def do_copy_params(lora_name: str, *args):
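+ # Load the saved training parameters for the selected LoRA (if present) and fall back to the current UI values for anything missing.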
+ f_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json"
+ if Path(f_name).is_file():
+ with open(f_name, 'r', encoding='utf-8') as format_file:
+ params: dict[str, str] = json.load(format_file)
+ else:
+ params = {}
+
+ result = list()
+ for i in range(0, len(PARAMETERS)):
+ key = PARAMETERS[i]
+ if key in params:
+ result.append(params[key])
+ else:
+ result.append(args[i])
+
+ return result
+
+
+def change_rank_limit(use_higher_ranks: bool):
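+ # Return gr.update-style dicts that raise the maximum values of the Rank and Alpha sliders when higher ranks are enabled.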
+ mult = 2 if use_higher_ranks else 1
+ return {"maximum": 1024 * mult, "__type__": "update"}, {"maximum": 2048 * mult, "__type__": "update"}
+
+
+def clean_path(base_path: str, path: str):
+ """"Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
+ # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
+ # Or swap it to a strict whitelist of [a-zA-Z_0-9]
+ path = path.replace('\\', '/').replace('..', '_')
+ if base_path is None:
+ return path
+
+ return f'{Path(base_path).absolute()}/{path}'
+
+
+def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, do_shuffle: bool, higher_rank_limit: bool, warmup_steps: int, optimizer: str):
+
+ if shared.args.monkey_patch:
+ from monkeypatch.peft_tuners_lora_monkey_patch import \
+ replace_peft_model_with_gptq_lora_model
+ replace_peft_model_with_gptq_lora_model()
+
+ global WANT_INTERRUPT
+ WANT_INTERRUPT = False
+
+ # == Input validation / processing ==
+ yield "Prepping..."
+ lora_file_path = clean_path(None, lora_name)
+ if lora_file_path.strip() == '':
+ yield "Missing or invalid LoRA file name input."
+ return
+
+ lora_file_path = f"{shared.args.lora_dir}/{lora_file_path}"
+ actual_lr = float(learning_rate)
+ model_type = type(shared.model).__name__
+
+ if model_type in MODEL_CLASSES:
+ model_id = MODEL_CLASSES[model_type]
+ else:
+ model_id = "llama"
+ if model_type == "PeftModelForCausalLM":
+ if len(shared.args.lora_names) > 0:
+ yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+ print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
+ else:
+ yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+ print("Warning: Model ID not matched due to LoRA loading. Consider reloading base model.")
+ else:
+ yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+ print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
+ time.sleep(5)
+
+ if shared.args.wbits > 0 and not shared.args.monkey_patch:
+ yield "LoRA training in 4-bit requires loading with `--monkey-patch`"
+ return
+
+ elif not shared.args.load_in_8bit and shared.args.wbits <= 0:
+ yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
+ print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
+ time.sleep(2) # Give it a moment for the message to show in UI before continuing
+
+ if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
+ yield "Cannot input zeroes."
+ return
+
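+ # The effective batch size (batch_size) is reached by accumulating gradients over micro batches of size micro_batch_size.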
+ gradient_accumulation_steps = batch_size // micro_batch_size
+ shared.tokenizer.pad_token = 0
+ shared.tokenizer.padding_side = "left"
+
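+ # Tokenize to cutoff_len + 1 tokens and drop the last one, so every example is exactly cutoff_len tokens long (left-padded).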
+ def tokenize(prompt):
+ result = shared.tokenizer(prompt, truncation=True, max_length=cutoff_len + 1, padding="max_length")
+ return {
+ "input_ids": result["input_ids"][:-1],
+ "attention_mask": result["attention_mask"][:-1],
+ }
+
+ # == Prep the dataset, format, etc ==
+ if raw_text_file not in ['None', '']:
+ print("Loading raw text file dataset...")
+ with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
+ raw_text = file.read()
+
+ tokens = shared.tokenizer.encode(raw_text)
+ del raw_text # Note: the raw text could be a gigabyte or more for a large dataset, so delete redundant data as we go to conserve RAM
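+ # Split into chunks of (cutoff_len - overlap_len) tokens, then prepend the last overlap_len tokens of the previous chunk so adjacent chunks share context.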
+ tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
+ for i in range(1, len(tokens)):
+ tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
+
+ text_chunks = [shared.tokenizer.decode(x) for x in tokens]
+ del tokens
+ if newline_favor_len > 0:
+ text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
+
+ train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
+ del text_chunks
+ eval_data = None
+
+ else:
+ if dataset in ['None', '']:
+ yield "**Missing dataset choice input, cannot continue.**"
+ return
+
+ if format in ['None', '']:
+ yield "**Missing format choice input, cannot continue.**"
+ return
+
+ with open(clean_path('training/formats', f'{format}.json'), 'r', encoding='utf-8') as formatFile:
+ format_data: dict[str, str] = json.load(formatFile)
+
+ def generate_prompt(data_point: dict[str, str]):
+ for options, data in format_data.items():
+ if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] is not None and len(x[1].strip()) > 0)):
+ for key, val in data_point.items():
+ if val is not None:
+ data = data.replace(f'%{key}%', val)
+ return data
+ raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
+
+ def generate_and_tokenize_prompt(data_point):
+ prompt = generate_prompt(data_point)
+ return tokenize(prompt)
+
+ print("Loading JSON datasets...")
+ data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
+ train_data = data['train'].map(generate_and_tokenize_prompt)
+
+ if eval_dataset == 'None':
+ eval_data = None
+ else:
+ eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
+ eval_data = eval_data['train'].map(generate_and_tokenize_prompt)
+ if do_shuffle:
+ eval_data = eval_data.shuffle()
+
+ if do_shuffle:
+ train_data = train_data.shuffle()
+
+ # == Start prepping the model itself ==
+ if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
+ print("Getting model ready...")
+ prepare_model_for_int8_training(shared.model)
+
+ print("Prepping for training...")
+ config = LoraConfig(
+ r=lora_rank,
+ lora_alpha=lora_alpha,
+ target_modules=model_to_lora_modules[model_id],
+ lora_dropout=lora_dropout,
+ bias="none",
+ task_type="CAUSAL_LM"
+ )
+
+ try:
+ print("Creating LoRA model...")
+ lora_model = get_peft_model(shared.model, config)
+ if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
+ print("Loading existing LoRA data...")
+ state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
+ set_peft_model_state_dict(lora_model, state_dict_peft)
+ except:
+ yield traceback.format_exc()
+ return
+
+ if shared.args.monkey_patch:
+ for n, m in lora_model.named_modules():
+ if '4bit' in str(type(m)):
+ if m.is_v1_model:
+ m.zeros = m.zeros.half()
+
+ m.scales = m.scales.half()
+
+ class Tracked():
+ def __init__(self):
+ self.current_steps = 0
+ self.max_steps = 0
+ self.did_save = False
+
+ tracked = Tracked()
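+ # The Trainer counts optimizer steps, while the UI shows micro steps; convert the save interval accordingly.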
+ actual_save_steps = math.ceil(save_steps / gradient_accumulation_steps)
+
+ class Callbacks(transformers.TrainerCallback):
+ def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
+ tracked.current_steps = state.global_step * gradient_accumulation_steps
+ tracked.max_steps = state.max_steps * gradient_accumulation_steps
+ if WANT_INTERRUPT:
+ control.should_epoch_stop = True
+ control.should_training_stop = True
+ elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
+ lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/")
+
+ def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
+ tracked.current_steps += 1
+ if WANT_INTERRUPT:
+ control.should_epoch_stop = True
+ control.should_training_stop = True
+
+ trainer = transformers.Trainer(
+ model=lora_model,
+ train_dataset=train_data,
+ eval_dataset=eval_data,
+ args=transformers.TrainingArguments(
+ per_device_train_batch_size=micro_batch_size,
+ gradient_accumulation_steps=gradient_accumulation_steps,
+ warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
+ num_train_epochs=epochs,
+ learning_rate=actual_lr,
+ fp16=False if shared.args.cpu else True,
+ optim=optimizer,
+ logging_steps=5,
+ evaluation_strategy="steps" if eval_data is not None else "no",
+ eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None,
+ save_strategy="no",
+ output_dir=lora_file_path,
+ lr_scheduler_type=lr_scheduler_type,
+ load_best_model_at_end=True if eval_data is not None else False,
+ # TODO: Enable multi-device support
+ ddp_find_unused_parameters=None,
+ no_cuda=shared.args.cpu
+ ),
+ data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
+ callbacks=list([Callbacks()])
+ )
+
+ lora_model.config.use_cache = False
+
+ if torch.__version__ >= "2" and sys.platform != "win32":
+ lora_model = torch.compile(lora_model)
+
+ # == Save parameters for reuse ==
+ with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
+ vars = locals()
+ json.dump({x: vars[x] for x in PARAMETERS}, file)
+
+ # == Main run and monitor loop ==
+ print("Starting training...")
+ yield "Starting..."
+ if WANT_INTERRUPT:
+ yield "Interrupted before start."
+ return
+
+ def threaded_run():
+ trainer.train()
+ # Note: save in this thread in case the Gradio thread breaks (e.g., the browser is closed)
+ lora_model.save_pretrained(lora_file_path)
+ print("LoRA training run is completed and saved.")
+ tracked.did_save = True
+
+ thread = threading.Thread(target=threaded_run)
+ thread.start()
+ last_step = 0
+ start_time = time.perf_counter()
+
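+ # Poll the training thread and report progress and an ETA to the UI; interruption itself is handled by the Trainer callbacks above.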
+ while thread.is_alive():
+ time.sleep(0.5)
+ if WANT_INTERRUPT:
+ yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
+
+ elif tracked.current_steps != last_step:
+ last_step = tracked.current_steps
+ time_elapsed = time.perf_counter() - start_time
+ if time_elapsed <= 0:
+ timer_info = ""
+ total_time_estimate = 999
+ else:
+ its = tracked.current_steps / time_elapsed
+ if its > 1:
+ timer_info = f"`{its:.2f}` it/s"
+ else:
+ timer_info = f"`{1.0/its:.2f}` s/it"
+
+ total_time_estimate = (1.0 / its) * (tracked.max_steps)
+
+ yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
+
+ # Saving in the training thread might fail if an error occurs there, so save here as a fallback.
+ if not tracked.did_save:
+ print("Training complete, saving...")
+ lora_model.save_pretrained(lora_file_path)
+
+ if WANT_INTERRUPT:
+ print("Training interrupted.")
+ yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`"
+ else:
+ print("Training complete!")
+ yield f"Done! LoRA saved to `{lora_file_path}`"
+
+
+def split_chunks(arr, step):
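+ # Yield consecutive slices of arr with length step (the final slice may be shorter).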
+ for i in range(0, len(arr), step):
+ yield arr[i:i + step]
+
+
+def cut_chunk_for_newline(chunk: str, max_length: int):
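+ # Trim the chunk to start just after its first newline and end just before its last newline, as long as each trim removes fewer than max_length characters.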
+ if '\n' not in chunk:
+ return chunk
+
+ first_newline = chunk.index('\n')
+ if first_newline < max_length:
+ chunk = chunk[first_newline + 1:]
+
+ if '\n' not in chunk:
+ return chunk
+
+ last_newline = chunk.rindex('\n')
+ if len(chunk) - last_newline < max_length:
+ chunk = chunk[:last_newline]
+
+ return chunk
+
+
+def format_time(seconds: float):
+ if seconds < 120:
+ return f"`{seconds:.0f}` seconds"
+
+ minutes = seconds / 60
+ if minutes < 120:
+ return f"`{minutes:.0f}` minutes"
+
+ hours = minutes / 60
+ return f"`{hours:.0f}` hours"
diff --git a/text-generation-webui-main/modules/ui.py b/text-generation-webui-main/modules/ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ddcc833942baaeec84c25d815589deadb5c2754
--- /dev/null
+++ b/text-generation-webui-main/modules/ui.py
@@ -0,0 +1,96 @@
+from pathlib import Path
+
+import gradio as gr
+import torch
+
+from modules import shared
+
+with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
+ css = f.read()
+with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f:
+ chat_css = f.read()
+with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
+ main_js = f.read()
+with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
+ chat_js = f.read()
+
+refresh_symbol = '\U0001f504' # 🔄
+theme = gr.themes.Default(
+ font=['Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'],
+ font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
+).set(
+ border_color_primary='#c5c5d2',
+ button_large_padding='6px 12px',
+ body_text_color_subdued='#484848',
+ background_fill_secondary='#eaeaea'
+)
+
+def list_model_elements():
+ elements = ['cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'wbits', 'groupsize', 'model_type', 'pre_layer']
+ for i in range(torch.cuda.device_count()):
+ elements.append(f'gpu_memory_{i}')
+ return elements
+
+
+def list_interface_input_elements(chat=False):
+ elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu']
+ if chat:
+ elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu']
+
+ elements += list_model_elements()
+ return elements
+
+
+def gather_interface_values(*args):
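+ # Zip the positional values coming from Gradio with shared.input_elements to rebuild the interface state dict.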
+ output = {}
+ for i, element in enumerate(shared.input_elements):
+ output[element] = args[i]
+
+ shared.persistent_interface_state = output
+ return output
+
+
+def apply_interface_values(state, use_persistent=False):
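+ # Map a saved state dict back onto the interface elements, skipping whichever menu does not apply to the current chat mode.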
+ if use_persistent:
+ state = shared.persistent_interface_state
+
+ elements = list_interface_input_elements(chat=shared.is_chat())
+ if len(state) == 0:
+ return [gr.update() for k in elements] # Dummy, do nothing
+ else:
+ if use_persistent and 'mode' in state:
+ if state['mode'] == 'instruct':
+ return [state[k] if k not in ['character_menu'] else gr.update() for k in elements]
+ else:
+ return [state[k] if k not in ['instruction_template'] else gr.update() for k in elements]
+ else:
+ return [state[k] for k in elements]
+
+
+class ToolButton(gr.Button, gr.components.FormComponent):
+ """Small button with single emoji as text, fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(variant="tool", **kwargs)
+
+ def get_block_name(self):
+ return "button"
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
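+ # Build a small refresh button that calls refresh_method() and then pushes refreshed_args (a dict, or a callable returning one) into refresh_component via gr.update.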
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
+
+ return gr.update(**(args or {}))
+
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
diff --git a/text-generation-webui-main/presets/Contrastive Search.txt b/text-generation-webui-main/presets/Contrastive Search.txt
new file mode 100644
index 0000000000000000000000000000000000000000..832bc9caf9b744d9d9c728f88d887f012a56ba3e
--- /dev/null
+++ b/text-generation-webui-main/presets/Contrastive Search.txt
@@ -0,0 +1,3 @@
+do_sample=False
+penalty_alpha=0.6
+top_k=4
diff --git a/text-generation-webui-main/presets/Debug-deterministic.txt b/text-generation-webui-main/presets/Debug-deterministic.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6673b71c8164effc401a486055b7f9a021b2acfb
--- /dev/null
+++ b/text-generation-webui-main/presets/Debug-deterministic.txt
@@ -0,0 +1 @@
+do_sample=False
diff --git a/text-generation-webui-main/presets/Default.txt b/text-generation-webui-main/presets/Default.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d28ce62f0e36d1f7824fe40d6e40018c9d78ea21
--- /dev/null
+++ b/text-generation-webui-main/presets/Default.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.5
+top_k=40
+temperature=0.7
+repetition_penalty=1.2
+typical_p=1.0
diff --git a/text-generation-webui-main/presets/Kobold-Godlike.txt b/text-generation-webui-main/presets/Kobold-Godlike.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0ba5b794b6d0130a1fa1d918bda9a276f7d23367
--- /dev/null
+++ b/text-generation-webui-main/presets/Kobold-Godlike.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.5
+top_k=0
+temperature=0.7
+repetition_penalty=1.1
+typical_p=0.19
diff --git a/text-generation-webui-main/presets/Kobold-Liminal Drift.txt b/text-generation-webui-main/presets/Kobold-Liminal Drift.txt
new file mode 100644
index 0000000000000000000000000000000000000000..be4dd3bd7a70af2d4eb6c847bed6bedee5379dce
--- /dev/null
+++ b/text-generation-webui-main/presets/Kobold-Liminal Drift.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=1.0
+top_k=0
+temperature=0.66
+repetition_penalty=1.1
+typical_p=0.6
diff --git a/text-generation-webui-main/presets/LLaMA-Precise.txt b/text-generation-webui-main/presets/LLaMA-Precise.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8098b390a097fc9438a2a82ec2bdd58adb2a771b
--- /dev/null
+++ b/text-generation-webui-main/presets/LLaMA-Precise.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.1
+top_k=40
+temperature=0.7
+repetition_penalty=1.18
+typical_p=1.0
diff --git a/text-generation-webui-main/presets/Naive.txt b/text-generation-webui-main/presets/Naive.txt
new file mode 100644
index 0000000000000000000000000000000000000000..aa8c058224c533f4084e230f6bbf77b63d5e81ea
--- /dev/null
+++ b/text-generation-webui-main/presets/Naive.txt
@@ -0,0 +1,4 @@
+do_sample=True
+temperature=0.7
+top_p=0.85
+top_k=50
diff --git a/text-generation-webui-main/presets/NovelAI-Best Guess.txt b/text-generation-webui-main/presets/NovelAI-Best Guess.txt
new file mode 100644
index 0000000000000000000000000000000000000000..db3fa75b2a11d7e29b108177f9894e82d1e52126
--- /dev/null
+++ b/text-generation-webui-main/presets/NovelAI-Best Guess.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.9
+top_k=100
+temperature=0.8
+repetition_penalty=1.15
+typical_p=1.0
diff --git a/text-generation-webui-main/presets/NovelAI-Decadence.txt b/text-generation-webui-main/presets/NovelAI-Decadence.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d3109f3e3f3a021810d171a0b98f615766b57e4b
--- /dev/null
+++ b/text-generation-webui-main/presets/NovelAI-Decadence.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=1.0
+top_k=100
+temperature=2
+repetition_penalty=1
+typical_p=0.97
diff --git a/text-generation-webui-main/presets/NovelAI-Genesis.txt b/text-generation-webui-main/presets/NovelAI-Genesis.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cc7376b3b981a260448a65cd3c00c7b3904308e2
--- /dev/null
+++ b/text-generation-webui-main/presets/NovelAI-Genesis.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.98
+top_k=0
+temperature=0.63
+repetition_penalty=1.05
+typical_p=1.0
diff --git a/text-generation-webui-main/presets/NovelAI-Lycaenidae.txt b/text-generation-webui-main/presets/NovelAI-Lycaenidae.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0134569cef76bc0de6b3dc7885d94d9d9afdfd62
--- /dev/null
+++ b/text-generation-webui-main/presets/NovelAI-Lycaenidae.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.85
+top_k=12
+temperature=2
+repetition_penalty=1.15
+typical_p=1.0
diff --git a/text-generation-webui-main/presets/NovelAI-Ouroboros.txt b/text-generation-webui-main/presets/NovelAI-Ouroboros.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1e944b54e78e1f63bd4bb6f56a717e0fec751c6b
--- /dev/null
+++ b/text-generation-webui-main/presets/NovelAI-Ouroboros.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=1.0
+top_k=100
+temperature=1.07
+repetition_penalty=1.05
+typical_p=1.0
diff --git a/text-generation-webui-main/presets/NovelAI-Pleasing Results.txt b/text-generation-webui-main/presets/NovelAI-Pleasing Results.txt
new file mode 100644
index 0000000000000000000000000000000000000000..330114a25db6d194dbc8689bf5476a81f649cf64
--- /dev/null
+++ b/text-generation-webui-main/presets/NovelAI-Pleasing Results.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=1.0
+top_k=0
+temperature=0.44
+repetition_penalty=1.15
+typical_p=1.0
diff --git a/text-generation-webui-main/presets/NovelAI-Sphinx Moth.txt b/text-generation-webui-main/presets/NovelAI-Sphinx Moth.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bace1e24b5dcc64fdde99097930f41a991e91b8e
--- /dev/null
+++ b/text-generation-webui-main/presets/NovelAI-Sphinx Moth.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.18
+top_k=30
+temperature=2.0
+repetition_penalty=1.15
+typical_p=1.0
diff --git a/text-generation-webui-main/presets/NovelAI-Storywriter.txt b/text-generation-webui-main/presets/NovelAI-Storywriter.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2df5f8181458c642ed4691925ade3d542de5391c
--- /dev/null
+++ b/text-generation-webui-main/presets/NovelAI-Storywriter.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.73
+top_k=0
+temperature=0.72
+repetition_penalty=1.1
+typical_p=1.0
diff --git a/text-generation-webui-main/presets/Verbose (Beam Search).txt b/text-generation-webui-main/presets/Verbose (Beam Search).txt
new file mode 100644
index 0000000000000000000000000000000000000000..464a4a5f0dda62348fda2cbbba4a98036c744d5c
--- /dev/null
+++ b/text-generation-webui-main/presets/Verbose (Beam Search).txt
@@ -0,0 +1,9 @@
+num_beams=10
+min_length=200
+length_penalty=1.4
+no_repeat_ngram_size=2
+early_stopping=True
+temperature=0.7
+top_k=150
+top_p=0.92
+repetition_penalty=4.5
diff --git a/text-generation-webui-main/prompts/Alpaca.txt b/text-generation-webui-main/prompts/Alpaca.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8434a80c3bcf35c5c62698ae31174f20f822cb6d
--- /dev/null
+++ b/text-generation-webui-main/prompts/Alpaca.txt
@@ -0,0 +1,6 @@
+Below is an instruction that describes a task. Write a response that appropriately completes the request.
+### Instruction:
+Write a poem about the transformers Python library.
+Mention the word "large language models" in that poem.
+### Response:
+
diff --git a/text-generation-webui-main/prompts/GPT-4chan.txt b/text-generation-webui-main/prompts/GPT-4chan.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1bc8c7f4613f982e3dfa367562a764cf5bd4c73b
--- /dev/null
+++ b/text-generation-webui-main/prompts/GPT-4chan.txt
@@ -0,0 +1,6 @@
+-----
+--- 865467536
+Hello, AI frens!
+How are you doing on this fine day?
+--- 865467537
+
diff --git a/text-generation-webui-main/prompts/Open Assistant.txt b/text-generation-webui-main/prompts/Open Assistant.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cf1ae4a2d0723afc8adee24fa496bafeaba0f492
--- /dev/null
+++ b/text-generation-webui-main/prompts/Open Assistant.txt
@@ -0,0 +1 @@
+<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>
diff --git a/text-generation-webui-main/prompts/QA.txt b/text-generation-webui-main/prompts/QA.txt
new file mode 100644
index 0000000000000000000000000000000000000000..32b0e2350f3c0a7f447dcd1aba11d6ae2247e5a8
--- /dev/null
+++ b/text-generation-webui-main/prompts/QA.txt
@@ -0,0 +1,4 @@
+Common sense questions and answers
+
+Question:
+Factual answer:
diff --git a/text-generation-webui-main/requirements.txt b/text-generation-webui-main/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2ee5274e9621ee6ccf4915f979ddf0ff06d49f34
--- /dev/null
+++ b/text-generation-webui-main/requirements.txt
@@ -0,0 +1,20 @@
+accelerate==0.18.0
+colorama
+datasets
+flexgen==0.1.7
+gradio==3.25.0
+markdown
+numpy
+pandas
+Pillow>=9.5.0
+pyyaml
+requests
+rwkv==0.7.3
+safetensors==0.3.0
+sentencepiece
+tqdm
+git+https://github.com/huggingface/peft
+transformers==4.28.1
+bitsandbytes==0.38.1; platform_system != "Windows"
+llama-cpp-python==0.1.36; platform_system != "Windows"
+https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.36/llama_cpp_python-0.1.36-cp310-cp310-win_amd64.whl; platform_system == "Windows"
diff --git a/text-generation-webui-main/server.py b/text-generation-webui-main/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..da786349e4650855d446e243f2bb9f12ab92b566
--- /dev/null
+++ b/text-generation-webui-main/server.py
@@ -0,0 +1,929 @@
+import os
+import requests
+import warnings
+
+os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+os.environ['BITSANDBYTES_NOWELCOME'] = '1'
+warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+
+# This is a hack to prevent Gradio from phoning home when it gets imported
+def my_get(url, **kwargs):
+ print('Gradio HTTP request redirected to localhost :)')
+ kwargs.setdefault('allow_redirects', True)
+ return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
+
+original_get = requests.get
+requests.get = my_get
+import gradio as gr
+requests.get = original_get
+
+# This fixes LaTeX rendering on some systems
+import matplotlib
+matplotlib.use('Agg')
+
+import importlib
+import io
+import json
+import math
+import os
+import re
+import sys
+import time
+import traceback
+import zipfile
+from datetime import datetime
+from functools import partial
+from pathlib import Path
+
+import psutil
+import torch
+import yaml
+from PIL import Image
+
+import modules.extensions as extensions_module
+from modules import chat, shared, training, ui
+from modules.html_generator import chat_html_wrapper
+from modules.LoRA import add_lora_to_model
+from modules.models import load_model, load_soft_prompt, unload_model
+from modules.text_generation import (encode, generate_reply,
+ stop_everything_event)
+
+
+def get_available_models():
+ if shared.args.flexgen:
+ return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
+ else:
+ return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json', '.yaml'))], key=str.lower)
+
+
+def get_available_presets():
+ return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower)
+
+
+def get_available_prompts():
+ prompts = []
+ prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
+ prompts += sorted(set((k.stem for k in Path('prompts').glob('*.txt'))), key=str.lower)
+ prompts += ['None']
+ return prompts
+
+
+def get_available_characters():
+ paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
+ return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
+
+
+def get_available_instruction_templates():
+ path = "characters/instruction-following"
+ paths = []
+ if os.path.exists(path):
+ paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
+ return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
+
+
+def get_available_extensions():
+ return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
+
+
+def get_available_softprompts():
+ return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower)
+
+
+def get_available_loras():
+ return sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
+
+
+def load_model_wrapper(selected_model):
+ try:
+ yield f"Loading {selected_model}..."
+ shared.model_name = selected_model
+ unload_model()
+ if selected_model != '':
+ shared.model, shared.tokenizer = load_model(shared.model_name)
+
+ yield f"Successfully loaded {selected_model}"
+ except:
+ yield traceback.format_exc()
+
+
+def load_lora_wrapper(selected_loras):
+ yield ("Applying the following LoRAs to {}:\n\n{}".format(shared.model_name, '\n'.join(selected_loras)))
+ add_lora_to_model(selected_loras)
+ yield ("Successfuly applied the LoRAs")
+
+
+def load_preset_values(preset_menu, state, return_dict=False):
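+ # Presets are plain text files of 'key=value' lines; each value is evaluated as a Python expression and merged over the defaults below.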
+ generate_params = {
+ 'do_sample': True,
+ 'temperature': 1,
+ 'top_p': 1,
+ 'typical_p': 1,
+ 'repetition_penalty': 1,
+ 'encoder_repetition_penalty': 1,
+ 'top_k': 50,
+ 'num_beams': 1,
+ 'penalty_alpha': 0,
+ 'min_length': 0,
+ 'length_penalty': 1,
+ 'no_repeat_ngram_size': 0,
+ 'early_stopping': False,
+ }
+ with open(Path(f'presets/{preset_menu}.txt'), 'r') as infile:
+ preset = infile.read()
+ for i in preset.splitlines():
+ i = i.rstrip(',').strip().split('=')
+ if len(i) == 2 and i[0].strip() != 'tokens':
+ generate_params[i[0].strip()] = eval(i[1].strip())
+ generate_params['temperature'] = min(1.99, generate_params['temperature'])
+
+ if return_dict:
+ return generate_params
+ else:
+ state.update(generate_params)
+ return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+
+
+def upload_soft_prompt(file):
+ with zipfile.ZipFile(io.BytesIO(file)) as zf:
+ zf.extract('meta.json')
+ j = json.loads(open('meta.json', 'r').read())
+ name = j['name']
+ Path('meta.json').unlink()
+
+ with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
+ f.write(file)
+
+ return name
+
+
+def save_prompt(text):
+ fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
+ with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
+ f.write(text)
+ return f"Saved to prompts/{fname}"
+
+
+def load_prompt(fname):
+ if fname in ['None', '']:
+ return ''
+ else:
+ with open(Path(f'prompts/{fname}.txt'), 'r', encoding='utf-8') as f:
+ text = f.read()
+ if text[-1] == '\n':
+ text = text[:-1]
+ return text
+
+
+def count_tokens(text):
+ tokens = len(encode(text)[0])
+ return f'{tokens} tokens in the input.'
+
+
+def download_model_wrapper(repo_id):
+ try:
+ downloader = importlib.import_module("download-model")
+
+ model = repo_id
+ branch = "main"
+ check = False
+
+ yield ("Cleaning up the model/branch names")
+ model, branch = downloader.sanitize_model_and_branch_names(model, branch)
+
+ yield ("Getting the download links from Hugging Face")
+ links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
+
+ yield ("Getting the output folder")
+ output_folder = downloader.get_output_folder(model, branch, is_lora)
+
+ if check:
+ yield ("Checking previously downloaded files")
+ downloader.check_model_files(model, branch, links, sha256, output_folder)
+ else:
+ yield (f"Downloading files to {output_folder}")
+ downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
+ yield ("Done!")
+ except:
+ yield traceback.format_exc()
+
+
+# Update the command-line arguments based on the interface values
+def update_model_parameters(state, initial=False):
+ elements = ui.list_model_elements() # the names of the parameters
+ gpu_memories = []
+
+ for i, element in enumerate(elements):
+ if element not in state:
+ continue
+
+ value = state[element]
+ if element.startswith('gpu_memory'):
+ gpu_memories.append(value)
+ continue
+
+ if initial and vars(shared.args)[element] != vars(shared.args_defaults)[element]:
+ continue
+
+ # Setting null defaults
+ if element in ['wbits', 'groupsize', 'model_type'] and value == 'None':
+ value = vars(shared.args_defaults)[element]
+ elif element in ['cpu_memory'] and value == 0:
+ value = vars(shared.args_defaults)[element]
+
+ # Making some simple conversions
+ if element in ['wbits', 'groupsize', 'pre_layer']:
+ value = int(value)
+ elif element == 'cpu_memory' and value is not None:
+ value = f"{value}MiB"
+
+ setattr(shared.args, element, value)
+
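+ # Apply the per-device gpu-memory values only if at least one is positive; otherwise clear the setting.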
+ found_positive = False
+ for i in gpu_memories:
+ if i > 0:
+ found_positive = True
+ break
+
+ if not (initial and vars(shared.args)['gpu_memory'] != vars(shared.args_defaults)['gpu_memory']):
+ if found_positive:
+ shared.args.gpu_memory = [f"{i}MiB" for i in gpu_memories]
+ else:
+ shared.args.gpu_memory = None
+
+
+def get_model_specific_settings(model):
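+ # Match the model name against the regex patterns in shared.model_config and merge the settings of every matching entry.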
+ settings = shared.model_config
+ model_settings = {}
+
+ for pat in settings:
+ if re.match(pat.lower(), model.lower()):
+ for k in settings[pat]:
+ model_settings[k] = settings[pat][k]
+
+ return model_settings
+
+
+def load_model_specific_settings(model, state, return_dict=False):
+ model_settings = get_model_specific_settings(model)
+ for k in model_settings:
+ if k in state:
+ state[k] = model_settings[k]
+
+ return state
+
+
+def save_model_settings(model, state):
+ if model == 'None':
+ yield ("Not saving the settings because no model is loaded.")
+ return
+
+ with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
+ if p.exists():
+ user_config = yaml.safe_load(open(p, 'r').read())
+ else:
+ user_config = {}
+
+ if model not in user_config:
+ user_config[model] = {}
+
+ for k in ui.list_model_elements():
+ user_config[model][k] = state[k]
+
+ with open(p, 'w') as f:
+ f.write(yaml.dump(user_config))
+
+ yield (f"Settings for {model} saved to {p}")
+
+
+def create_model_menus():
+ # Finding the default values for the GPU and CPU memories
+ total_mem = []
+ for i in range(torch.cuda.device_count()):
+ total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)))
+
+ default_gpu_mem = []
+ if shared.args.gpu_memory is not None and len(shared.args.gpu_memory) > 0:
+ for i in shared.args.gpu_memory:
+ if 'mib' in i.lower():
+ default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)))
+ else:
+ default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)) * 1000)
+ while len(default_gpu_mem) < len(total_mem):
+ default_gpu_mem.append(0)
+
+ total_cpu_mem = math.floor(psutil.virtual_memory().total / (1024 * 1024))
+ if shared.args.cpu_memory is not None:
+ default_cpu_mem = re.sub('[a-zA-Z ]', '', shared.args.cpu_memory)
+ else:
+ default_cpu_mem = 0
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ shared.gradio['model_menu'] = gr.Dropdown(choices=get_available_models(), value=shared.model_name, label='Model')
+ ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button')
+
+ with gr.Column():
+ with gr.Row():
+ shared.gradio['lora_menu'] = gr.Dropdown(multiselect=True, choices=get_available_loras(), value=shared.lora_names, label='LoRA(s)')
+ ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras(), 'value': shared.lora_names}, 'refresh-button')
+
+ with gr.Column():
+ with gr.Row():
+ shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
+ with gr.Row():
+ unload = gr.Button("Unload the model")
+ reload = gr.Button("Reload the model")
+ save_settings = gr.Button("Save settings for this model")
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Box():
+ gr.Markdown('Transformers parameters')
+ with gr.Row():
+ with gr.Column():
+ for i in range(len(total_mem)):
+ shared.gradio[f'gpu_memory_{i}'] = gr.Slider(label=f"gpu-memory in MiB for device :{i}", maximum=total_mem[i], value=default_gpu_mem[i])
+ shared.gradio['cpu_memory'] = gr.Slider(label="cpu-memory in MiB", maximum=total_cpu_mem, value=default_cpu_mem)
+
+ with gr.Column():
+ shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
+ shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
+ shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu)
+ shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
+ shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
+
+ with gr.Column():
+ with gr.Box():
+ gr.Markdown('GPTQ parameters')
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=shared.args.wbits if shared.args.wbits > 0 else "None")
+ shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128], value=shared.args.groupsize if shared.args.groupsize > 0 else "None")
+
+ with gr.Column():
+ shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
+ shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer)
+
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter Hugging Face username/model path, e.g.: facebook/galactica-125m")
+ shared.gradio['download_model_button'] = gr.Button("Download")
+
+ with gr.Column():
+ shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')
+
+ # In this event handler, the interface state is read and updated
+ # with the model defaults (if any), and then the model is loaded
+ shared.gradio['model_menu'].change(
+ ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+ load_model_specific_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['interface_state']).then(
+ ui.apply_interface_values, shared.gradio['interface_state'], [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False).then(
+ update_model_parameters, shared.gradio['interface_state'], None).then(
+ load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=True)
+
+ unload.click(
+ unload_model, None, None).then(
+ lambda: "Model unloaded", None, shared.gradio['model_status'])
+
+ reload.click(
+ unload_model, None, None).then(
+ ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+ update_model_parameters, shared.gradio['interface_state'], None).then(
+ load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=False)
+
+ save_settings.click(
+ ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+ save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['model_status'], show_progress=False)
+
+ shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
+ shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
+
+
+def create_settings_menus(default_preset):
+
+ generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ shared.gradio['preset_menu'] = gr.Dropdown(choices=get_available_presets(), value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
+ ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': get_available_presets()}, 'refresh-button')
+ with gr.Column():
+ shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Box():
+ gr.Markdown('Custom generation parameters ([click here to view technical documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature', info='Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness.')
+ shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p', info='If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.')
+ shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k', info='Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.')
+ shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p', info='If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.')
+ with gr.Column():
+ shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty', info='Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.')
+ shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty', info='Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.')
+ shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size', info='If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.')
+ shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length', info='Minimum generation length in tokens.')
+ shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
+ with gr.Column():
+ with gr.Box():
+ gr.Markdown('Contrastive search')
+ shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
+
+ gr.Markdown('Beam search (uses a lot of VRAM)')
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
+ shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
+ with gr.Column():
+ shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
+
+ with gr.Box():
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
+ shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
+ with gr.Column():
+ shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
+ shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
+
+ shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
+
+ with gr.Accordion('Soft prompt', open=False):
+ with gr.Row():
+ shared.gradio['softprompts_menu'] = gr.Dropdown(choices=get_available_softprompts(), value='None', label='Soft prompt')
+ ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': get_available_softprompts()}, 'refresh-button')
+
+ gr.Markdown('Upload a soft prompt (.zip format):')
+ with gr.Row():
+ shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
+
+ shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
+ shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
+ shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
+
+
+def set_interface_arguments(interface_mode, extensions, bool_active):
+ modes = ["default", "notebook", "chat", "cai_chat"]
+ cmd_list = vars(shared.args)
+ bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
+
+ shared.args.extensions = extensions
+ for k in modes[1:]:
+ setattr(shared.args, k, False)
+ if interface_mode != "default":
+ setattr(shared.args, interface_mode, True)
+
+ for k in bool_list:
+ setattr(shared.args, k, False)
+ for k in bool_active:
+ setattr(shared.args, k, True)
+
+ shared.need_restart = True
+
+
+def create_interface():
+
+ # Defining some variables
+ gen_events = []
+ default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
+ if len(shared.lora_names) == 1:
+ default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_names[0].lower())), 'default')])
+ else:
+ default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
+ title = 'Text generation web UI'
+
+ # Authentication variables
+ auth = None
+ if shared.args.gradio_auth_path is not None:
+ gradio_auth_creds = []
+ with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file:
+ for line in file.readlines():
+ gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
+ auth = [tuple(cred.split(':')) for cred in gradio_auth_creds]
+
+ # Importing the extension files and executing their setup() functions
+ if shared.args.extensions is not None and len(shared.args.extensions) > 0:
+ extensions_module.load_extensions()
+
+ with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title, theme=ui.theme) as shared.gradio['interface']:
+
+ # Create chat mode interface
+ if shared.is_chat():
+ shared.input_elements = ui.list_interface_input_elements(chat=True)
+ shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
+ shared.gradio['Chat input'] = gr.State()
+
+ with gr.Tab('Text generation', elem_id='main'):
+ shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
+ shared.gradio['textbox'] = gr.Textbox(label='Input')
+ with gr.Row():
+ shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop')
+ shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate', variant='primary')
+ shared.gradio['Continue'] = gr.Button('Continue')
+
+ with gr.Row():
+ shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
+ shared.gradio['Regenerate'] = gr.Button('Regenerate')
+ shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
+
+ with gr.Row():
+ shared.gradio['Impersonate'] = gr.Button('Impersonate')
+ shared.gradio['Send dummy message'] = gr.Button('Send dummy message')
+ shared.gradio['Send dummy reply'] = gr.Button('Send dummy reply')
+
+ with gr.Row():
+ shared.gradio['Remove last'] = gr.Button('Remove last')
+ shared.gradio['Clear history'] = gr.Button('Clear history')
+ shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant='stop', visible=False)
+ shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
+
+ shared.gradio['mode'] = gr.Radio(choices=['cai-chat', 'chat', 'instruct'], value=shared.settings['mode'], label='Mode')
+ shared.gradio['instruction_template'] = gr.Dropdown(choices=get_available_instruction_templates(), label='Instruction template', value=shared.settings['instruction_template'], visible=shared.settings['mode'] == 'instruct', info='Change this according to the model/LoRA that you are using.')
+
+ with gr.Tab('Character', elem_id='chat-settings'):
+ with gr.Row():
+ with gr.Column(scale=8):
+ shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
+ shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
+ shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting')
+ shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context')
+ shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings['end_of_turn'], lines=1, label='End of turn string')
+
+ with gr.Column(scale=1):
+ shared.gradio['character_picture'] = gr.Image(label='Character picture', type='pil')
+ shared.gradio['your_picture'] = gr.Image(label='Your picture', type='pil', value=Image.open(Path('cache/pfp_me.png')) if Path('cache/pfp_me.png').exists() else None)
+
+ with gr.Row():
+ shared.gradio['character_menu'] = gr.Dropdown(choices=get_available_characters(), value='None', label='Character', elem_id='character-menu')
+ ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': get_available_characters()}, 'refresh-button')
+
+ with gr.Row():
+ with gr.Tab('Chat history'):
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown('Upload')
+ shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'])
+
+ with gr.Column():
+ gr.Markdown('Download')
+ shared.gradio['download'] = gr.File()
+ shared.gradio['download_button'] = gr.Button(value='Click me')
+
+ with gr.Tab('Upload character'):
+ gr.Markdown('# JSON format')
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown('1. Select the JSON file')
+ shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'])
+
+ with gr.Column():
+ gr.Markdown('2. Select your character\'s profile picture (optional)')
+ shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image'])
+
+ shared.gradio['Upload character'] = gr.Button(value='Submit')
+ gr.Markdown('# TavernAI PNG format')
+ shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
+
+ with gr.Tab("Parameters", elem_id="parameters"):
+ with gr.Box():
+ gr.Markdown("Chat parameters")
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+ shared.gradio['chat_prompt_size'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
+
+ with gr.Column():
+ shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
+ shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character')
+
+ create_settings_menus(default_preset)
+
+ # Create notebook mode interface
+ elif shared.args.notebook:
+ shared.input_elements = ui.list_interface_input_elements(chat=False)
+ shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
+ shared.gradio['last_input'] = gr.State('')
+ with gr.Tab("Text generation", elem_id="main"):
+ with gr.Row():
+ with gr.Column(scale=4):
+ with gr.Tab('Raw'):
+ shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes="textbox", lines=27)
+
+ with gr.Tab('Markdown'):
+ shared.gradio['markdown'] = gr.Markdown()
+
+ with gr.Tab('HTML'):
+ shared.gradio['html'] = gr.HTML()
+
+ with gr.Row():
+ shared.gradio['Generate'] = gr.Button('Generate', variant='primary', elem_classes="small-button")
+ shared.gradio['Stop'] = gr.Button('Stop', elem_classes="small-button")
+ shared.gradio['Undo'] = gr.Button('Undo', elem_classes="small-button")
+ shared.gradio['Regenerate'] = gr.Button('Regenerate', elem_classes="small-button")
+
+ with gr.Column(scale=1):
+ gr.HTML('