Tonic committed on
Commit
9641cfa
·
verified ·
1 Parent(s): 8ff5503

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -36
app.py CHANGED
@@ -25,43 +25,43 @@ os.system('python -m spacy download en_core_web_sm')
25
  nlp = spacy.load("en_core_web_sm")
26
 
27
def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
    """Generate a historical-text correction for *prompt* and return token-level highlights.

    Returns a tuple ``(highlighted_text, generated_text)`` where
    ``highlighted_text`` is a list of ``(token, token_type)`` pairs and
    ``generated_text`` is the decoded model output (the part after the
    "### Correction ###" marker when present).
    """
    # Run the whole pipeline without autograd bookkeeping — inference only.
    with torch.no_grad():
        prompt = f"### Text ###\n{prompt}"
        encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
        ids = encoded["input_ids"].to(device)
        mask = encoded["attention_mask"].to(device)

        generation = model.generate(
            ids,
            attention_mask=mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            top_k=top_k,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            repetition_penalty=repetition_penalty,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        decoded = tokenizer.decode(generation[0], skip_special_tokens=True)

        # Keep only the model's answer section when the marker is present.
        if "### Correction ###" in decoded:
            decoded = decoded.split("### Correction ###")[1].strip()

        pieces = tokenizer.tokenize(decoded)

        # Pair each surface token with its round-tripped (token -> id -> token)
        # form; "Ġ" is the BPE word-boundary marker and is stripped for display.
        highlighted = [
            (
                piece.replace("Ġ", ""),
                tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(piece)])[0].replace("Ġ", ""),
            )
            for piece in pieces
        ]

        # Drop tensors promptly and release cached GPU memory.
        del encoded, ids, mask, generation, pieces
        torch.cuda.empty_cache()

        return highlighted, decoded
65
 
66
  def text_analysis(text):
67
  doc = nlp(text)
 
25
  nlp = spacy.load("en_core_web_sm")
26
 
27
def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
    """Generate a historical-text correction for *prompt* and return token-level highlights.

    Parameters
    ----------
    prompt : str
        Input text; wrapped in a "### Text ###" header before generation.
    max_new_tokens, top_k, temperature, top_p, repetition_penalty
        Standard sampling parameters forwarded to ``model.generate``.

    Returns
    -------
    tuple[list[tuple[str, str]], str]
        ``(highlighted_text, generated_text)``: per-token ``(token, token_type)``
        pairs, and the decoded text (the portion after "### Correction ###"
        when that marker appears in the output).
    """
    # Fix: the previous revision left `with torch.no_grad():` commented out as
    # dead code. Restored here — this is pure inference, and disabling autograd
    # avoids needless gradient tracking (``generate`` is no-grad internally,
    # but the tokenize/decode path benefits and the dead comment is removed).
    with torch.no_grad():
        prompt = f"### Text ###\n{prompt}"
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            top_k=top_k,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            repetition_penalty=repetition_penalty,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

        # Keep only the model's answer section when the marker is present.
        if "### Correction ###" in generated_text:
            generated_text = generated_text.split("### Correction ###")[1].strip()

        tokens = tokenizer.tokenize(generated_text)

        highlighted_text = []
        for token in tokens:
            # "Ġ" is the BPE word-boundary marker; strip it for display.
            clean_token = token.replace("Ġ", "")
            # NOTE(review): token -> id -> token round-trip; for known tokens
            # this should equal clean_token (differs only for unknown tokens
            # mapped to UNK) — confirm whether a distinct "type" was intended.
            token_type = tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(token)])[0].replace("Ġ", "")
            highlighted_text.append((clean_token, token_type))

        # Drop tensors promptly and release cached GPU memory.
        del inputs, input_ids, attention_mask, output, tokens
        torch.cuda.empty_cache()

        return highlighted_text, generated_text
65
 
66
  def text_analysis(text):
67
  doc = nlp(text)