liujch1998 commited on
Commit
11a61b4
β€’
1 Parent(s): 36a1394

Add debug mode

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -9,8 +9,13 @@ import shutil
9
 
10
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
11
 
 
 
 
 
12
  HF_TOKEN_DOWNLOAD = os.environ['HF_TOKEN_DOWNLOAD']
13
  HF_TOKEN_UPLOAD = os.environ['HF_TOKEN_UPLOAD']
 
14
 
15
  MODEL_NAME = 'liujch1998/cd-pi'
16
  DATASET_REPO_URL = "https://huggingface.co/datasets/liujch1998/cd-pi-dataset"
@@ -32,6 +37,8 @@ repo.git_pull()
32
  class Interactive:
33
  def __init__(self):
34
  self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD)
 
 
35
  self.model = transformers.T5EncoderModel.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD, low_cpu_mem_usage=True, device_map='auto', torch_dtype='auto')
36
  self.linear = torch.nn.Linear(self.model.shared.embedding_dim, 1, dtype=self.model.dtype).to(device)
37
  self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0)) # (1, D)
@@ -40,6 +47,13 @@ class Interactive:
40
  self.t = self.model.shared.weight[32097, 0].item()
41
 
42
  def run(self, statement):
 
 
 
 
 
 
 
43
  input_ids = self.tokenizer.batch_encode_plus([statement], return_tensors='pt', padding='longest').input_ids.to(device)
44
  with torch.no_grad():
45
  output = self.model(input_ids)
@@ -55,16 +69,10 @@ class Interactive:
55
  'score': score.item(),
56
  'score_calibrated': score_calibrated.item(),
57
  }
58
- # return {
59
- # 'logit': 0.0,
60
- # 'logit_calibrated': 0.0,
61
- # 'score': 0.5,
62
- # 'score_calibrated': 0.5,
63
- # }
64
 
65
  interactive = Interactive()
66
 
67
- def predict(statement):
68
  result = interactive.run(statement)
69
  with open(DATA_PATH, 'a') as f:
70
  row = {
@@ -155,7 +163,7 @@ description = '''This is a demo for a commonsense statement verification model.
155
 
156
  gr.Interface(
157
  fn=predict,
158
- inputs=[input_statement],
159
  outputs=output,
160
  title="cd-pi Demo",
161
  description=description,
 
9
 
10
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
11
 
12
+ # To suppress the following warning:
13
+ # huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
14
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
15
+
16
  HF_TOKEN_DOWNLOAD = os.environ['HF_TOKEN_DOWNLOAD']
17
  HF_TOKEN_UPLOAD = os.environ['HF_TOKEN_UPLOAD']
18
+ MODE = os.environ['MODE']
19
 
20
  MODEL_NAME = 'liujch1998/cd-pi'
21
  DATASET_REPO_URL = "https://huggingface.co/datasets/liujch1998/cd-pi-dataset"
 
37
  class Interactive:
38
  def __init__(self):
39
  self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD)
40
+ if MODE == 'debug':
41
+ return
42
  self.model = transformers.T5EncoderModel.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD, low_cpu_mem_usage=True, device_map='auto', torch_dtype='auto')
43
  self.linear = torch.nn.Linear(self.model.shared.embedding_dim, 1, dtype=self.model.dtype).to(device)
44
  self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0)) # (1, D)
 
47
  self.t = self.model.shared.weight[32097, 0].item()
48
 
49
  def run(self, statement):
50
+ if MODE == 'debug':
51
+ return {
52
+ 'logit': 0.0,
53
+ 'logit_calibrated': 0.0,
54
+ 'score': 0.5,
55
+ 'score_calibrated': 0.5,
56
+ }
57
  input_ids = self.tokenizer.batch_encode_plus([statement], return_tensors='pt', padding='longest').input_ids.to(device)
58
  with torch.no_grad():
59
  output = self.model(input_ids)
 
69
  'score': score.item(),
70
  'score_calibrated': score_calibrated.item(),
71
  }
 
 
 
 
 
 
72
 
73
  interactive = Interactive()
74
 
75
+ def predict(statement, model):
76
  result = interactive.run(statement)
77
  with open(DATA_PATH, 'a') as f:
78
  row = {
 
163
 
164
  gr.Interface(
165
  fn=predict,
166
+ inputs=[input_statement, input_model],
167
  outputs=output,
168
  title="cd-pi Demo",
169
  description=description,