zetavg commited on
Commit
9cd5ad7
1 Parent(s): a5e11b9
llama_lora/ui/css_styles.py CHANGED
@@ -1,4 +1,6 @@
1
- css_styles = []
 
 
2
 
3
 
4
  def get_css_styles():
 
1
+ from typing import List
2
+
3
+ css_styles: List[str] = []
4
 
5
 
6
  def get_css_styles():
llama_lora/ui/inference_ui.py CHANGED
@@ -556,9 +556,10 @@ def inference_ui():
556
  elem_id="inference_inference_raw_output_accordion"
557
  ) as raw_output_group:
558
  inference_raw_output = gr.Code(
559
- label="Raw Output",
560
- show_label=False,
561
  language="json",
 
562
  interactive=False,
563
  elem_id="inference_raw_output")
564
 
 
556
  elem_id="inference_inference_raw_output_accordion"
557
  ) as raw_output_group:
558
  inference_raw_output = gr.Code(
559
+ # label="Raw Output",
560
+ label="Tensor",
561
  language="json",
562
+ lines=8,
563
  interactive=False,
564
  elem_id="inference_raw_output")
565
 
llama_lora/ui/tokenizer_ui.py CHANGED
@@ -46,6 +46,7 @@ def tokenizer_ui():
46
  encoded_tokens = gr.Code(
47
  label="Encoded Tokens (JSON)",
48
  language="json",
 
49
  value=sample_encoded_tokens_value,
50
  elem_id="tokenizer_encoded_tokens_input_textbox")
51
  decode_btn = gr.Button("Decode ➡️")
@@ -54,6 +55,7 @@ def tokenizer_ui():
54
  with gr.Column():
55
  decoded_tokens = gr.Code(
56
  label="Decoded Tokens",
 
57
  value=sample_decoded_text_value,
58
  elem_id="tokenizer_decoded_text_input_textbox")
59
  encode_btn = gr.Button("⬅️ Encode")
 
46
  encoded_tokens = gr.Code(
47
  label="Encoded Tokens (JSON)",
48
  language="json",
49
+ lines=10,
50
  value=sample_encoded_tokens_value,
51
  elem_id="tokenizer_encoded_tokens_input_textbox")
52
  decode_btn = gr.Button("Decode ➡️")
 
55
  with gr.Column():
56
  decoded_tokens = gr.Code(
57
  label="Decoded Tokens",
58
+ lines=10,
59
  value=sample_decoded_text_value,
60
  elem_id="tokenizer_decoded_text_input_textbox")
61
  encode_btn = gr.Button("⬅️ Encode")
llama_lora/utils/prompter.py CHANGED
@@ -7,7 +7,7 @@ import json
7
  import os.path as osp
8
  import importlib
9
  import itertools
10
- from typing import Union, List
11
 
12
  from ..config import Config
13
  from ..globals import Global
@@ -38,9 +38,10 @@ class Prompter(object):
38
  raise ValueError(f"Can't read {file_path}")
39
 
40
  if ext == ".py":
41
- template_module_spec = importlib.util.spec_from_file_location(
 
42
  "template_module", file_path)
43
- template_module = importlib.util.module_from_spec(
44
  template_module_spec)
45
  template_module_spec.loader.exec_module(template_module)
46
  self.template_module = template_module
@@ -67,7 +68,7 @@ class Prompter(object):
67
 
68
  def generate_prompt(
69
  self,
70
- variables: List[Union[None, str]] = [],
71
  # instruction: str,
72
  # input: Union[None, str] = None,
73
  label: Union[None, str] = None,
@@ -75,10 +76,14 @@ class Prompter(object):
75
  if self.template_name == "None":
76
  if type(variables) == list:
77
  res = get_val(variables, 0, "")
78
- else:
79
  res = variables.get("prompt", "")
 
 
80
  elif "variables" in self.template:
81
  variable_names = self.template.get("variables")
 
 
82
  if self.template_module:
83
  if type(variables) == list:
84
  variables = {k: v for k, v in zip(
 
7
  import os.path as osp
8
  import importlib
9
  import itertools
10
+ from typing import Union, List, Dict
11
 
12
  from ..config import Config
13
  from ..globals import Global
 
38
  raise ValueError(f"Can't read {file_path}")
39
 
40
  if ext == ".py":
41
+ importlib_util = importlib.util # type: ignore
42
+ template_module_spec = importlib_util.spec_from_file_location(
43
  "template_module", file_path)
44
+ template_module = importlib_util.module_from_spec(
45
  template_module_spec)
46
  template_module_spec.loader.exec_module(template_module)
47
  self.template_module = template_module
 
68
 
69
  def generate_prompt(
70
  self,
71
+ variables: Union[Dict[str, str], List[Union[None, str]]] = [],
72
  # instruction: str,
73
  # input: Union[None, str] = None,
74
  label: Union[None, str] = None,
 
76
  if self.template_name == "None":
77
  if type(variables) == list:
78
  res = get_val(variables, 0, "")
79
+ elif type(variables) == dict:
80
  res = variables.get("prompt", "")
81
+ else:
82
+ raise ValueError(f"Invalid variables type: {type(variables)}")
83
  elif "variables" in self.template:
84
  variable_names = self.template.get("variables")
85
+ # if type(variable_names) != list:
86
+ # raise ValueError(f"Invalid variable_names type {type(variable_names)} defined in template {self.template_name}, expecting list.")
87
  if self.template_module:
88
  if type(variables) == list:
89
  variables = {k: v for k, v in zip(