Text2Text Generation
Transformers
PyTorch
t5
codet5
text-generation-inference
nielsr HF staff commited on
Commit
f23c18c
1 Parent(s): 1a501f4

Update code example

Browse files
Files changed (1) hide show
  1. README.md +4 -48
README.md CHANGED
@@ -36,7 +36,7 @@ See the [model hub](https://huggingface.co/models?search=salesforce/codet) to lo
36
 
37
  ### How to use
38
 
39
- Here is how to use this model for masked span prediction:
40
 
41
  ```python
42
  from transformers import RobertaTokenizer, T5ForConditionalGeneration
@@ -44,57 +44,13 @@ from transformers import RobertaTokenizer, T5ForConditionalGeneration
44
  tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
45
  model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')
46
 
47
- text = "def greet(user): print(f'hello <extra_id_0>!') </s>"
48
- inputs = tokenizer(text, return_tensors="pt").input_ids
49
 
50
  # simply generate a single sequence
51
  generated_ids = model.generate(input_ids, max_length=8)
52
  print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
53
- # this prints {user.name}
54
-
55
- # or, generating 20 sequences with maximum length set to 10
56
- outputs = model.generate(input_ids=input_ids,
57
- num_beams=200, num_return_sequences=20,
58
- max_length=10)
59
-
60
- _0_index = text.index('<extra_id_0>')
61
- _result_prefix = text[:_0_index]
62
- _result_suffix = text[_0_index+12:] # 12 is the length of <extra_id_0>
63
-
64
- def _filter(output, end_token='<extra_id_1>'):
65
- # The first token is <pad> (indexed at 0), the second token is <s> (indexed at 1)
66
- # and the third token is <extra_id_0> (indexed at 32099)
67
- # So we only decode from the fourth generated id
68
- _txt = tokenizer.decode(output[3:], skip_special_tokens=False, clean_up_tokenization_spaces=False)
69
- if end_token in _txt:
70
- _end_token_index = _txt.index(end_token)
71
- return _result_prefix + _txt[:_end_token_index] + _result_suffix
72
- else:
73
- return _result_prefix + _txt + _result_suffix
74
-
75
- results = list(map(_filter, outputs))
76
- print(results)
77
- # this prints:
78
- #["def greet(user): print(f'hello {user.name} {user!') </s>",
79
- # "def greet(user): print(f'hello {user.username} {user!') </s>",
80
- # "def greet(user): print(f'hello {user.name}: {user!') </s>",
81
- # "def greet(user): print(f'hello {user}') print(f!') </s>",
82
- # "def greet(user): print(f'hello {user.name} �!') </s>",
83
- # "def greet(user): print(f'hello {user}') print ( f!') </s>",
84
- # "def greet(user): print(f'hello {user.username}: {user!') </s>",
85
- # "def greet(user): print(f'hello {user}' ) print(f!') </s>",
86
- # "def greet(user): print(f'hello {user.username} �!') </s>",
87
- # "def greet(user): print(f'hello {user.name}, {user!') </s>",
88
- # "def greet(user): print(f'hello {user.login} {user!') </s>",
89
- # "def greet(user): print(f'hello {user} →!') </s>",
90
- # "def greet(user): print(f'hello {user}!') print(!') </s>",
91
- # "def greet(user): print(f'hello {user.name} ({user!') </s>",
92
- # "def greet(user): print(f'hello {user.email} {user!') </s>",
93
- # "def greet(user): print(f'hello {user}!') print (!') </s>",
94
- # "def greet(user): print(f'hello {user.username}, {user!') </s>",
95
- # "def greet(user): print(f'hello {user}' ) print ( f!') </s>",
96
- # "def greet(user): print(f'hello {user.nickname} {!') </s>",
97
- # "def greet(user): print(f'hello {user} {user.name!') </s>"]
98
  ```
99
 
100
  ## Training data
 
36
 
37
  ### How to use
38
 
39
+ Here is how to use this model:
40
 
41
  ```python
42
  from transformers import RobertaTokenizer, T5ForConditionalGeneration
 
44
  tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
45
  model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')
46
 
47
+ text = "def greet(user): print(f'hello <extra_id_0>!')"
48
+ input_ids = tokenizer(text, return_tensors="pt").input_ids
49
 
50
  # simply generate a single sequence
51
  generated_ids = model.generate(input_ids, max_length=8)
52
  print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
53
+ # this prints "{user.username}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  ```
55
 
56
  ## Training data