autonomous019 commited on
Commit
173b282
1 Parent(s): 79e0b51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py CHANGED
@@ -16,6 +16,26 @@ from transformers import (
16
  StoppingCriteriaList,
17
  MaxLengthCriteria,
18
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # https://github.com/NielsRogge/Transformers-Tutorials/blob/master/HuggingFace_vision_ecosystem_overview_(June_2022).ipynb
21
  # option 1: load with randomly initialized weights (train from scratch)
@@ -39,6 +59,52 @@ model = PerceiverForImageClassificationConvProcessing.from_pretrained("deepmind/
39
  #google/vit-base-patch16-224, deepmind/vision-perceiver-conv
40
  image_pipe = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def create_story(text_seed):
43
  #tokenizer = AutoTokenizer.from_pretrained("gpt2")
44
  #model = AutoModelForCausalLM.from_pretrained("gpt2")
 
16
  StoppingCriteriaList,
17
  MaxLengthCriteria,
18
  )
19
+ import json
20
+ import os
21
+ from screenshot import (
22
+ before_prompt,
23
+ prompt_to_generation,
24
+ after_generation,
25
+ js_save,
26
+ js_load_script,
27
+ )
28
+ from spaces_info import description, examples, initial_prompt_value
29
+
30
+ API_URL = os.getenv("API_URL")
31
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
32
+
33
+
34
+ def query(payload):
35
+ print(payload)
36
+ response = requests.request("POST", API_URL, json=payload, headers={"Authorization": f"Bearer {HF_API_TOKEN}"})
37
+ print(response)
38
+ return json.loads(response.content.decode("utf-8"))
39
 
40
  # https://github.com/NielsRogge/Transformers-Tutorials/blob/master/HuggingFace_vision_ecosystem_overview_(June_2022).ipynb
41
  # option 1: load with randomly initialized weights (train from scratch)
 
59
  #google/vit-base-patch16-224, deepmind/vision-perceiver-conv
60
  image_pipe = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
61
 
62
+
63
+
64
+
65
+ def inference(input_sentence, max_length, sample_or_greedy, seed=42):
66
+ if sample_or_greedy == "Sample":
67
+ parameters = {
68
+ "max_new_tokens": max_length,
69
+ "top_p": 0.9,
70
+ "do_sample": True,
71
+ "seed": seed,
72
+ "early_stopping": False,
73
+ "length_penalty": 0.0,
74
+ "eos_token_id": None,
75
+ }
76
+ else:
77
+ parameters = {
78
+ "max_new_tokens": max_length,
79
+ "do_sample": False,
80
+ "seed": seed,
81
+ "early_stopping": False,
82
+ "length_penalty": 0.0,
83
+ "eos_token_id": None,
84
+ }
85
+
86
+ payload = {"inputs": input_sentence, "parameters": parameters,"options" : {"use_cache": False} }
87
+
88
+ data = query(payload)
89
+
90
+ if "error" in data:
91
+ return (None, None, f"<span style='color:red'>ERROR: {data['error']} </span>")
92
+
93
+ generation = data[0]["generated_text"].split(input_sentence, 1)[1]
94
+ return (
95
+ before_prompt
96
+ + input_sentence
97
+ + prompt_to_generation
98
+ + generation
99
+ + after_generation,
100
+ data[0]["generated_text"],
101
+ "",
102
+ )
103
+
104
+
105
+
106
+
107
+
108
  def create_story(text_seed):
109
  #tokenizer = AutoTokenizer.from_pretrained("gpt2")
110
  #model = AutoModelForCausalLM.from_pretrained("gpt2")