autonomous019 committed
Commit
b0a4b77
1 Parent(s): b676b33

adding story func

Files changed (1)
app.py +47 -1
app.py CHANGED
@@ -13,6 +13,21 @@ import torch
 # https://github.com/NielsRogge/Transformers-Tutorials/blob/master/HuggingFace_vision_ecosystem_overview_(June_2022).ipynb
 # option 1: load with randomly initialized weights (train from scratch)
 
+from transformers import (
+    AutoModelForCausalLM,
+    LogitsProcessorList,
+    MinLengthLogitsProcessor,
+    StoppingCriteriaList,
+    MaxLengthCriteria,
+)
+
 config = ViTConfig(num_hidden_layers=12, hidden_size=768)
 model = ViTForImageClassification(config)
 
@@ -28,6 +43,36 @@ model = PerceiverForImageClassificationConvProcessing.from_pretrained("deepmind/
 #google/vit-base-patch16-224, deepmind/vision-perceiver-conv
 image_pipe = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
 
+def create_story(text_seed):
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+    # set pad_token_id to eos_token_id because GPT-2 does not have a PAD token
+    model.config.pad_token_id = model.config.eos_token_id
+
+    #input_prompt = "It might be possible to"
+    input_prompt = text_seed
+    input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
+
+    # instantiate logits processors: force at least 10 tokens before EOS is allowed
+    logits_processor = LogitsProcessorList(
+        [
+            MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id),
+        ]
+    )
+    # stop once the full sequence (prompt included) reaches 20 tokens
+    stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
+
+    outputs = model.greedy_search(
+        input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
+    )
+
+    result_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    return result_text
+
 def self_caption(image):
     repo_name = "ydshieh/vit-gpt2-coco-en"
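
A note on the generation setup added above: greedy_search is the low-level explicit decoding API, and newer transformers releases deprecate calling it directly in favor of model.generate, which builds the same processors internally. A roughly equivalent high-level call (a sketch assuming such a version, not part of this commit):

# generate() with do_sample=False performs greedy decoding; min_length and
# max_length count the prompt tokens too, matching MinLengthLogitsProcessor(10)
# and MaxLengthCriteria(max_length=20) above.
outputs = model.generate(
    input_ids,
    do_sample=False,
    min_length=10,
    max_length=20,
    pad_token_id=model.config.eos_token_id,
)

Because the 20-token budget includes the caption used as the prompt, the generated "story" is only a short continuation.
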
@@ -60,7 +105,8 @@ def self_caption(image):
     print(pred_dictionary)
     #return(pred_dictionary)
     preds = ' '.join(preds)
-    return preds
+    story = create_story(preds)
+    return story
 
 
 def classify_image(image):
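
One behavioral detail of the last hunk: tokenizer.batch_decode returns a list of strings, so self_caption now returns a one-element list where it previously returned the joined caption string. A minimal sketch (not part of the commit) of unwrapping it for callers that expect plain text:

# create_story returns batch_decode's list; take the first element for a string.
story = create_story(preds)
story_text = story[0] if isinstance(story, list) else story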
 
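End to end, the app's flow becomes image -> caption -> story. A hypothetical usage sketch (the PIL import and image path are illustrative, not from app.py):

from PIL import Image

image = Image.open("example.jpg")  # illustrative input
story = self_caption(image)        # captions the image, then expands the caption via create_story
print(story)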