akshaytrikha commited on
Commit
8cce078
1 Parent(s): 8748247

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -3
app.py CHANGED
@@ -1,10 +1,20 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
 
4
- generator = pipeline('text-generation', model='pytorch_model.bin')
 
 
 
 
 
 
 
 
 
 
5
 
6
  def generate(text):
7
- result = generator(text, max_length=30, num_return_sequences=1)
8
  return result[0]["generated_text"]
9
 
10
  examples = [
 
1
  import gradio as gr
2
+ from transformers import GPT2LMHeadModel, pipeline
3
 
4
+ device = 'cuda' if args.cuda else 'cpu'
5
+
6
+ # load pretrained + finetuned GPT2
7
+ model = GPT2LMHeadModel.from_pretrained("./model/pytorch_model.bin", from_pt=True)
8
+ # model = GPT2LMHeadModel.from_pretrained("/zxc/model_epoch40_50w")
9
+ model = model.to(device)
10
+
11
+
12
+ # generator = pipeline('text-generation', model=model)
13
+
14
+ trump = pipeline("text-generation", model=model, tokenizer=tokenizer, config={"max_length":140})
15
 
16
  def generate(text):
17
+ result = trump(text, num_return_sequences=1)
18
  return result[0]["generated_text"]
19
 
20
  examples = [