xumingyu16 committed
Commit e1288db · verified · 1 Parent(s): 28d4b82

Update README.md

Files changed (1)
  1. README.md +14 -31
README.md CHANGED
@@ -342,37 +342,20 @@ import torch
 model_name = "baichuan-inc/Baichuan-M1-14B-Base"
 tokenizer = AutoTokenizer.from_pretrained(model_name,trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(model_name,trust_remote_code=True,torch_dtype = torch.bfloat16).cuda()
-# 2. Input prompt text
-prompt = "May I ask you some questions about medical knowledge?"
-
-# 3. Encode the input text for the model
-messages = [
-    {"role": "system", "content": "You are a helpful assistant."},
-    {"role": "user", "content": prompt}
-]
-text = tokenizer.apply_chat_template(
-    messages,
-    tokenize=False,
-    add_generation_prompt=True
-)
-model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-
-# 4. Generate text
-generated_ids = model.generate(
-    **model_inputs,
-    max_new_tokens=512
-)
-generated_ids = [
-    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-]
-
-# 5. Decode the generated text
-response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-
-# 6. Output the result
-print("Generated text:")
-print(response)
+
+input_text = "I have recently recovered from my cold."
+
+
+inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
+
+outputs = model.generate(
+    inputs["input_ids"],
+    max_length=100,
+)
+
+generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+print("Generated Text:")
+print(generated_text)
 ```
 
 ---
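For readers who want to run the updated snippet directly, below is a self-contained version of the post-change example. This is a minimal sketch, not part of the commit: the `torch` and `transformers` imports are assumed from the top of the README's example (the hunk's context line is `import torch`), a CUDA-capable GPU is assumed for the `.cuda()` call, and the comments are editorial.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "baichuan-inc/Baichuan-M1-14B-Base"

# Load the tokenizer and the model in bfloat16 on the GPU
# (trust_remote_code=True loads this repository's custom model code).
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, trust_remote_code=True, torch_dtype=torch.bfloat16
).cuda()

# The Base variant does plain text completion, so the prompt is
# tokenized directly rather than going through a chat template.
input_text = "I have recently recovered from my cold."
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

# Note: max_length bounds prompt plus generated tokens combined;
# max_new_tokens would bound only the completion length.
outputs = model.generate(inputs["input_ids"], max_length=100)

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)
```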