henry2024 commited on
Commit
e1c0816
·
verified ·
1 Parent(s): 5f71b26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -59,15 +59,15 @@ def get_model(base_model: str = "bigcode/starcoder",):
59
  return tokenizer, model
60
 
61
 
62
- class RNN_model(nn.Module):
63
  def __init__(self):
64
  super().__init__()
65
 
66
- self.rnn= nn.GRU(input_size=1080, hidden_size=240,num_layers=1, nonlinearity= 'relu', bias= True)
67
  self.output= nn.Linear(in_features=240, out_features=24)
68
 
69
  def forward(self, x):
70
- y, hidden= self.rnn(x)
71
  #print(y.shape)
72
  #print(hidden.shape)
73
  x= self.output(y)
@@ -112,7 +112,7 @@ vectorizer= nltk_u.vectorizer()
112
  vectorizer.fit(train_data.text)
113
 
114
  # Model and transforms preparation
115
- model= RNN_model()
116
  # Load state dict
117
  model.load_state_dict(torch.load(
118
  f= 'pretrained_symtom_to_disease_model.pth',
@@ -183,7 +183,7 @@ with gr.Blocks(css = """#col_container { margin-left: auto; margin-right: auto;}
183
  time.sleep(2)
184
  return "", chat_history
185
  '''
186
- def respond(message, chat_history, base_model = "microsoft/phi-2", device=device): # "meta-llama/Meta-Llama-3-70B"
187
  if base_model != "microsoft/phi-2":
188
  # Random greetings in list format
189
  greetings = [
 
59
  return tokenizer, model
60
 
61
 
62
+ class GRU_model(nn.Module):
63
  def __init__(self):
64
  super().__init__()
65
 
66
+ self.gru= nn.GRU(input_size=1080, hidden_size=240,num_layers=1, nonlinearity= 'relu', bias= True)
67
  self.output= nn.Linear(in_features=240, out_features=24)
68
 
69
  def forward(self, x):
70
+ y, hidden= self.gru(x)
71
  #print(y.shape)
72
  #print(hidden.shape)
73
  x= self.output(y)
 
112
  vectorizer.fit(train_data.text)
113
 
114
  # Model and transforms preparation
115
+ model= GRU_model()
116
  # Load state dict
117
  model.load_state_dict(torch.load(
118
  f= 'pretrained_symtom_to_disease_model.pth',
 
183
  time.sleep(2)
184
  return "", chat_history
185
  '''
186
+ def respond(message, chat_history, base_model = "self_GRU", device=device): # "meta-llama/Meta-Llama-3-70B"
187
  if base_model != "microsoft/phi-2":
188
  # Random greetings in list format
189
  greetings = [