teatwots committed on
Commit
700ece2
1 Parent(s): 4f72f15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -21
app.py CHANGED
@@ -2,49 +2,79 @@
2
  from transformers import T5Tokenizer, T5ForConditionalGeneration
3
  import gradio as gr
4
  import nltk
5
- from nltk.tokenize import sent_tokenize
6
- import difflib
 
7
 
8
- # Download the punkt tokenizer for sentence splitting
9
  nltk.download('punkt')
 
 
10
 
11
  # Load a pre-trained T5 model specifically fine-tuned for grammar correction
12
  tokenizer = T5Tokenizer.from_pretrained("prithivida/grammar_error_correcter_v1")
13
  model = T5ForConditionalGeneration.from_pretrained("prithivida/grammar_error_correcter_v1")
14
 
15
- # Function to perform grammar correction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def grammar_check(text):
17
- # Split the text into sentences
18
  sentences = sent_tokenize(text)
19
  corrected_sentences = []
20
- original_sentences = []
 
21
 
22
  for sentence in sentences:
23
- original_sentences.append(sentence)
24
  input_text = f"gec: {sentence}"
25
  input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
26
  outputs = model.generate(input_ids, max_length=512, num_beams=4, early_stopping=True)
27
  corrected_sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
  corrected_sentences.append(corrected_sentence)
 
29
 
30
  # Function to underline and color revised parts
31
  def underline_and_color_revisions(original, corrected):
32
- diff = difflib.ndiff(original.split(), corrected.split())
33
  result = []
34
- for word in diff:
35
- if word.startswith("+ "):
36
- result.append(f"<u style='color:red;'>{word[2:]}</u>")
37
- elif word.startswith("- "):
38
- continue
39
- else:
40
- result.append(word[2:])
41
  return " ".join(result)
42
 
43
- # Join the corrected sentences back into a single string
44
  corrected_text = " ".join(
45
- underline_and_color_revisions(orig, corr) for orig, corr in zip(original_sentences, corrected_sentences)
46
  )
47
- return corrected_text
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  # Create Gradio interface with a writing prompt
50
  interface = gr.Interface(
@@ -57,10 +87,10 @@ interface = gr.Interface(
57
  "Writing Prompt:\n"
58
  "In the story, Alex and his friends discovered an ancient treasure in Whispering Hollow and decided to donate the artifacts to the local museum.\n\n"
59
  "In the past, did you have a similar experience where you found something valuable or interesting? Tell the story. Describe what you found, what you did with it, and how you felt about your decision.\n\n"
60
- "Remember to use past tense in your writing.\n"
61
- "Sample text for testing: When I was 10, I find an old coin in my backyard. I kept it for a while and shows it to my friends. They was impressed and say it might be valuable. Later, I take it to a local antique shop, and the owner told me it was very old. I decided to give it to the museum in my town. The museum was happy and put it on display. I feel proud of my decision."
62
  )
63
  )
64
 
65
  # Launch the interface
66
- interface.launch()
 
 
2
  from transformers import T5Tokenizer, T5ForConditionalGeneration
3
  import gradio as gr
4
  import nltk
5
+ from nltk.tokenize import sent_tokenize, word_tokenize
6
+ from nltk.corpus import wordnet as wn
7
+ from difflib import SequenceMatcher
8
 
9
# Download the NLTK data used below: sentence/word tokenizer models, the POS
# tagger model and WordNet (needed by the lemmatizer).
# NOTE(review): recent NLTK releases look up 'punkt_tab' and
# 'averaged_perceptron_tagger_eng' instead of the legacy pickled packages, so
# both generations are requested here. nltk.download() reports a package it
# cannot fetch and returns False rather than raising by default -- confirm
# against the NLTK version this app is deployed with.
for _nltk_resource in (
    'punkt',
    'punkt_tab',
    'averaged_perceptron_tagger',
    'averaged_perceptron_tagger_eng',
    'wordnet',
):
    nltk.download(_nltk_resource)

# Load a pre-trained T5 model specifically fine-tuned for grammar correction
tokenizer = T5Tokenizer.from_pretrained("prithivida/grammar_error_correcter_v1")
model = T5ForConditionalGeneration.from_pretrained("prithivida/grammar_error_correcter_v1")
17
 
18
def get_base_form(word, tag):
    """Return the base form (lemma) of *word* when *tag* marks it as a verb.

    Args:
        word: A single token.
        tag: The part-of-speech tag paired with the token (the callers in
            this file pass tags produced by ``nltk.pos_tag``).

    Returns:
        For the tags VB, VBD, VBG, VBN, VBP and VBZ, the lower-cased WordNet
        verb lemma of *word*; for any other tag, *word* unchanged.
    """
    verb_tags = {'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ'}
    if tag not in verb_tags:
        return word
    # The lemmatizer only matches lower-case WordNet entries, so a capitalised
    # verb (e.g. at the start of a sentence) would otherwise come back
    # unchanged and look like a different base form to the caller.
    return nltk.WordNetLemmatizer().lemmatize(word.lower(), wn.VERB)
25
+
26
def extract_verbs(sentence):
    """Return the ``(word, tag)`` pairs of *sentence* whose tag starts with 'VB'.

    The sentence is split with ``word_tokenize`` and tagged with
    ``nltk.pos_tag``; pairs are returned in sentence order.
    """
    tagged_tokens = nltk.pos_tag(word_tokenize(sentence))
    return [
        (token, pos)
        for token, pos in tagged_tokens
        if pos.startswith('VB')
    ]
32
+
33
def _underline_and_color_revisions(original, corrected):
    """Return *corrected* as HTML with inserted/replaced words underlined in red."""
    import html

    original_words = original.split()
    corrected_words = corrected.split()
    matcher = SequenceMatcher(None, original_words, corrected_words)

    pieces = []
    for tag, _i1, _i2, j1, j2 in matcher.get_opcodes():
        if tag == 'delete':
            # Words that only exist in the original are dropped from the output.
            continue
        # For 'equal' spans both sides hold the same words, so the corrected
        # slice can be used for every remaining opcode. The text comes from
        # the user / the model, so escape it before embedding it in markup.
        segment = html.escape(' '.join(corrected_words[j1:j2]), quote=False)
        if tag == 'equal':
            pieces.append(segment)
        else:  # 'insert' or 'replace'
            pieces.append(f"<u style='color:red;'>{segment}</u>")
    return " ".join(pieces)


def _revised_verb_forms(original_sentence, corrected_sentence):
    """List ``lemma-corrected-lemma`` entries for verbs whose form changed.

    Verbs are paired by position within the sentence, which is a heuristic:
    a correction that adds or removes a verb shifts the pairs after it.
    """
    import html

    entries = []
    verb_pairs = zip(extract_verbs(original_sentence), extract_verbs(corrected_sentence))
    for (orig_word, orig_tag), (corr_word, corr_tag) in verb_pairs:
        # Compare the written forms, not the lemmas: a tense or agreement fix
        # such as find -> found keeps the same lemma and must still be listed.
        if orig_word.lower() == corr_word.lower():
            continue
        base_orig = get_base_form(orig_word, orig_tag)
        base_corr = get_base_form(corr_word, corr_tag)
        entries.append(html.escape(f"{base_orig}-{corr_word}-{base_corr}", quote=False))
    return entries


def grammar_check(text):
    """Correct *text* sentence by sentence and report the revised verbs.

    Each sentence from ``sent_tokenize`` is sent to the module-level T5
    ``model`` with the ``gec:`` prefix; the decoded output is diffed against
    the original sentence word by word.

    Args:
        text: The text to check.

    Returns:
        An HTML string: the corrected sentences joined by spaces, with
        inserted or replaced words underlined in red, followed by a
        "Revised Verb Forms" heading and one entry per changed verb.
    """
    highlighted_sentences = []
    verb_forms_list = []

    for sentence in sent_tokenize(text):
        input_text = f"gec: {sentence}"
        input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
        outputs = model.generate(input_ids, max_length=512, num_beams=4, early_stopping=True)
        corrected_sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)

        highlighted_sentences.append(_underline_and_color_revisions(sentence, corrected_sentence))
        # Pair verbs per sentence so a mismatch in one sentence cannot
        # misalign the verbs of every sentence after it.
        verb_forms_list.extend(_revised_verb_forms(sentence, corrected_sentence))

    corrected_text = " ".join(highlighted_sentences)
    verb_forms_str = "\n".join(verb_forms_list)

    # Return combined result
    return f"{corrected_text}\n\n<b>Revised Verb Forms:</b>\n{verb_forms_str}"
78
 
79
  # Create Gradio interface with a writing prompt
80
  interface = gr.Interface(
 
87
  "Writing Prompt:\n"
88
  "In the story, Alex and his friends discovered an ancient treasure in Whispering Hollow and decided to donate the artifacts to the local museum.\n\n"
89
  "In the past, did you have a similar experience where you found something valuable or interesting? Tell the story. Describe what you found, what you did with it, and how you felt about your decision.\n\n"
90
+ "Remember to use past tense in your writing."
 
91
  )
92
  )
93
 
94
  # Launch the interface
95
+ interface.launch()
96
+