lodestones commited on
Commit
42fd79d
1 Parent(s): 441708b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +18 -0
README.md CHANGED
@@ -70,6 +70,24 @@ print(tokenizer.decode(outputs[0]))
70
  ```
71
 
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  #### Running the model on a single / multi GPU
74
 
75
 
 
70
  ```
71
 
72
 
73
+ #### Running the model using Flax on a single GPU / TPU
74
+
75
+
76
+ ```python
77
+ from transformers import AutoTokenizer, FlaxGemmaForCausalLM
78
+
79
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
80
+ model = FlaxGemmaForCausalLM.from_pretrained("google/gemma-7b-flax")
81
+ model.params = jax.tree_map(lambda x: jax.device_put(x, jax.devices()[0]).astype(jnp.float16), flax.params)
82
+
83
+ input_text = "Write me a poem about Machine Learning."
84
+ input_ids = jax.device_put(tokenizer(input_text, return_tensors="jax"), jax.devices()[0])
85
+
86
+ outputs = model.generate(**input_ids)
87
+ print(tokenizer.decode(outputs[0][0]))
88
+ ```
89
+
90
+
91
  #### Running the model on a single / multi GPU
92
 
93