dwb2023 commited on
Commit
86fbb40
1 Parent(s): 8e01118

Update app.py

Browse files

add CUDA config

Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -1,10 +1,18 @@
1
  import gradio as gr
 
2
  from transformers import AutoModel
3
 
4
  @spaces.GPU
5
  def get_model_summary(model_name):
6
- model = AutoModel.from_pretrained(model_name)
 
 
 
 
 
 
7
  return str(model)
8
 
 
9
  interface = gr.Interface(fn=get_model_summary, inputs="text", outputs="text")
10
  interface.launch()
 
1
  import gradio as gr
2
+ import torch
3
  from transformers import AutoModel
4
 
5
  @spaces.GPU
6
  def get_model_summary(model_name):
7
+ # Check if CUDA is available and set the device accordingly
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ # Load the model and move it to the selected device
11
+ model = AutoModel.from_pretrained(model_name).to(device)
12
+
13
+ # Return the model's architecture as a string
14
  return str(model)
15
 
16
+ # Create the Gradio interface
17
  interface = gr.Interface(fn=get_model_summary, inputs="text", outputs="text")
18
  interface.launch()