abhisheksan committed on
Commit
f12e1f2
·
verified ·
1 Parent(s): 2505881

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -21
app.py CHANGED
@@ -1,26 +1,58 @@
1
  import gradio as gr
 
 
 
2
 
3
- # Load the model directly using Gradio's load method
4
- interface = gr.load("models/meta-llama/Llama-3.2-3B-Instruct")
5
 
6
- # Define the function to call the model with additional parameters
7
- def custom_chat_fn(prompt, max_length=100, temperature=0.7):
8
- # Call the Gradio-loaded model with prompt and parameters
9
- response = interface(prompt, max_length=max_length, temperature=temperature)
10
- return response[0]['generated_text']
11
 
12
- # Create a new interface with parameter inputs using the updated Gradio components
13
- parameterized_interface = gr.Interface(
14
- fn=custom_chat_fn,
15
- inputs=[
16
- gr.Textbox(label="Prompt"),
17
- gr.Slider(50, 500, step=10, label="Max Length", value=100),
18
- gr.Slider(0.1, 1.0, step=0.1, label="Temperature", value=0.7)
19
- ],
20
- outputs="text",
21
- live=True,
22
- allow_flagging='never',
23
- description="Enter a prompt, set the max length and temperature, then click 'Submit' to generate text."
24
- )
25
 
26
- parameterized_interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Optional
 
# FastAPI application instance.
app = FastAPI()

# Load the hosted model a single time at process startup so every request
# reuses the same callable instead of re-downloading/re-wrapping the model.
model = gr.load("models/meta-llama/Llama-3.2-3B-Instruct")
class PoemRequest(BaseModel):
    """Request payload for poem generation: a prompt plus sampling knobs.

    All sampling parameters are optional and range-validated by pydantic;
    out-of-range values are rejected with a 422 before the handler runs.
    """

    prompt: str = Field(..., description="The prompt for poem generation")
    temperature: Optional[float] = Field(0.7, ge=0.1, le=2.0, description="Controls randomness in generation")
    top_p: Optional[float] = Field(0.9, ge=0.1, le=1.0, description="Nucleus sampling parameter")
    top_k: Optional[int] = Field(50, ge=1, le=100, description="Top-k sampling parameter")
    max_length: Optional[int] = Field(200, ge=50, le=500, description="Maximum length of generated text")
    repetition_penalty: Optional[float] = Field(1.1, ge=1.0, le=2.0, description="Penalty for repetition")
 
 
 
 
 
 
class PoemResponse(BaseModel):
    """Response payload: the generated poem and the settings that produced it."""

    poem: str
    parameters_used: dict
@app.post("/generate_poem")
async def generate_poem(request: PoemRequest) -> PoemResponse:
    """
    Generate a poem based on the provided prompt and parameters.

    Args:
        request: Validated generation request (prompt plus sampling knobs).

    Returns:
        PoemResponse: The generated poem and the parameters used.

    Raises:
        HTTPException: 500 carrying the underlying error message if the
            model call fails.
    """
    # Collect the sampling parameters once so the same dict can be forwarded
    # to the model and echoed back to the caller in the response.
    generation_config = {
        "temperature": request.temperature,
        "top_p": request.top_p,
        "top_k": request.top_k,
        "max_length": request.max_length,
        "repetition_penalty": request.repetition_penalty,
    }

    try:
        # NOTE(review): gr.load() yields a callable that proxies the hosted
        # model; it is assumed here to accept these generation kwargs and to
        # return a plain string -- confirm against the gradio client API.
        poem = model(request.prompt, **generation_config)
    except Exception as e:
        # Boundary handler: surface any generation failure as an HTTP 500.
        # (HTTPException must be imported from fastapi; the original code
        # raised it without importing it, producing a NameError instead of
        # a clean 500 response.)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return PoemResponse(poem=poem, parameters_used=generation_config)
if __name__ == "__main__":
    # Allow running the API directly as a script; import uvicorn lazily so
    # it is only required when serving, not when the module is imported.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)