yjf9966 committed
Commit 296d86a
1 Parent(s): c099cdf

Update README.md

Files changed (1)
  1. README.md +61 -28
README.md CHANGED
@@ -57,35 +57,68 @@ Users (both direct and downstream) should be made aware of the risks, biases and
 
  Use the code below to get started with the model.
 
- ```
+ ```python
+ from transformers import LlamaForCausalLM, LlamaTokenizer
  import torch
- import transformers
- from transformers import LlamaTokenizer, LlamaForCausalLM
-
- def generate_prompt(text):
-     return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" +
-     ### Instruction:\n\n{text}\n\n### Response:\n\n"""
-
- tokenizer = LlamaTokenizer.from_pretrained('BlueWhaleX/bwx-13B-HF')
- model = LlamaForCausalLM.from_pretrained('BlueWhaleX/bwx-13B-HF').half().cuda()
- model.eval()
-
- text = '王国维说:“自周之衰,文王、周公势力之瓦解也,国民之智力成熟于内,政治之纷乱乘之于外,上无统一之制度,下迫于社会之要求,于是诸于九流各创其学说。” 他意在说明 A. 分封制的崩溃 B. 商鞅变法的作用 C. 兼并战争的后果 D. 百家争鸣的原因'
- prompt = generate_prompt(text)
- input_ids = tokenizer.encode(prompt, return_tensors='pt').to('cuda')
-
- with torch.no_grad():
-     output_ids = model.generate(
-         input_ids=input_ids,
-         max_new_tokens=400,
-         temperature=0.2,
-         top_k=40,
-         top_p=0.9,
-         repetition_penalty=1.3
-     ).cuda()
- output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
- response = output.split("### Response:")[1].strip()
- print("Response: ", response, '\n')
+
+ base_model_name = "BlueWhaleX/bwx-13B-hf"
+ load_type = torch.float16
+ device = None
+
+ generation_config = dict(
+     temperature=0.2,
+     top_k=40,
+     top_p=0.9,
+     do_sample=True,
+     num_beams=1,
+     repetition_penalty=1.3,
+     max_new_tokens=400
+ )
+
+ prompt_input = (
+     "Below is an instruction that describes a task. "
+     "Write a response that appropriately completes the request.\n\n"
+     "### Instruction:\n\n{instruction}\n\n### Response:\n\n"
+ )
+ if torch.cuda.is_available():
+     device = torch.device(0)
+ else:
+     device = torch.device('cpu')
+
+ def generate_prompt(instruction, input=None):
+     if input:
+         instruction = instruction + '\n' + input
+     return prompt_input.format_map({'instruction': instruction})
+
+ tokenizer = LlamaTokenizer.from_pretrained(base_model_name)
+ model = LlamaForCausalLM.from_pretrained(
+     base_model_name,
+     load_in_8bit=False,
+     torch_dtype=load_type,
+     low_cpu_mem_usage=True,
+     device_map='auto',
+ )
+
+ model_vocab_size = model.get_input_embeddings().weight.size(0)
+ tokenzier_vocab_size = len(tokenizer)
+ if model_vocab_size != tokenzier_vocab_size:
+     model.resize_token_embeddings(tokenzier_vocab_size)
+
+ raw_input_text = input("Input:")
+ input_text = generate_prompt(instruction=raw_input_text)
+ inputs = tokenizer(input_text, return_tensors="pt")
+ generation_output = model.generate(
+     input_ids=inputs["input_ids"].to(device),
+     attention_mask=inputs['attention_mask'].to(device),
+     eos_token_id=tokenizer.eos_token_id,
+     pad_token_id=tokenizer.pad_token_id,
+     **generation_config
+ )
+ s = generation_output[0]
+ output = tokenizer.decode(s, skip_special_tokens=True)
+ response = output.split("### Response:")[1].strip()
+ print("Response: ", response)
+ print("\n")
  ```
 
 
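
For readers who want to sanity-check the prompt format without loading the 13B weights, here is a minimal sketch that exercises only the prompt template and the "### Response:" parsing from the updated snippet above. The example instruction and the simulated model output are illustrative assumptions, not part of this commit.

```python
# Minimal sketch: prompt construction and response parsing only (no model needed).
# The instruction text and the simulated model output below are placeholders.
prompt_input = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n\n{instruction}\n\n### Response:\n\n"
)

def generate_prompt(instruction, input=None):
    # Optional extra context is appended to the instruction, as in the snippet above.
    if input:
        instruction = instruction + '\n' + input
    return prompt_input.format_map({'instruction': instruction})

prompt = generate_prompt("解释一下百家争鸣的原因。")  # placeholder instruction
# The decoded generation echoes the prompt followed by the answer, which is why
# the README snippet recovers the answer by splitting on the "### Response:" marker.
simulated_output = prompt + "这是模型的示例回答。"  # placeholder model output
response = simulated_output.split("### Response:")[1].strip()
print("Response: ", response)
```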