Empereur-Pirate committed
Commit 6f96f84
1 Parent(s): e8d01f8

Update main.py

Files changed (1): main.py (+29 -22)
main.py CHANGED
@@ -1,9 +1,9 @@
-from fastapi import FastAPI
-from fastapi.responses import FileResponse
+from fastapi import FastAPI, Request
+from fastapi.responses import FileResponse, JSONResponse
 from fastapi.staticfiles import StaticFiles
 from transformers import pipeline
 from pydantic import BaseModel
-from typing import Union
+from typing import Optional, Any
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 
@@ -14,24 +14,28 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
 
 pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
 
-def t5(input):
+def t5(input: str) -> dict[str, str]:
     output = pipe_flan(input)
-    return {"output": output[0]["generated_text"]}
+    return {"output": output[0].get("generated_text", "")}
 
 @app.post("/infer_t5")
-def infer_endpoint(data: dict):
+async def infer_endpoint(data: dict = Depends(parse_raw)):
     """Receive input and generate text."""
-    input_text = data.get("input")
+    try:
+        input_text = data.get("input")
 
-    # Validate that the input is a string
-    assert isinstance(input_text, str), "Input must be a string."
+        # Validate that the input is a string
+        assert isinstance(input_text, str), "Input must be a string."
+
+        if input_text is None:
+            return {"error": "No input text detected."}
 
-    if input_text is None:
-        return {"error": "No input text detected."}
-    else:
         result = t5(input_text)
         return result
 
+    except AssertionError as e:
+        return JSONResponse({"error": f"Invalid Input Format: {e}"}, status_code=400)
+
 @app.get("/infer_t5")
 def get_default_inference_endpoint():
     return {"message": "Use POST method to submit input data"}
@@ -55,27 +59,30 @@ except Exception as e:
 model = AutoModelForCausalLM.from_pretrained("miqudev/miqu-1-70b", config=model_config).to(device)
 tokenizer = AutoTokenizer.from_pretrained("miqudev/miqu-1-70b")
 
-def miuk_answer(query: str):
+def miuk_answer(query: str) -> str:
     query_tokens = tokenizer.encode(query, return_tensors="pt")
     query_tokens = query_tokens.to(device)
     answer = model.generate(query_tokens, max_length=128, temperature=1, pad_token_id=tokenizer.pad_token_id)
     return tokenizer.decode(answer[:, 0]).replace(" ", "")
 
 @app.post("/infer_miku")
-def infer_endpoint(data: dict):
+async def infer_endpoint(data: dict = Depends(parse_raw)):
     """Receive input and generate text."""
-    input_text = data.get("input")
+    try:
+        input_text = data.get("input")
+
+        # Validate that the input is a string
+        assert isinstance(input_text, str), "Input must be a string."
 
-    # Validate that the input is a string
-    assert isinstance(input_text, str), "Input must be a string."
+        if input_text is None:
+            return {"error": "No input text detected."}
 
-    if input_text is None:
-        return {"error": "No input text detected."}
-    else:
         result = {"output": miuk_answer(input_text)}
         return result
 
+    except AssertionError as e:
+        return JSONResponse({"error": f"Invalid Input Format: {e}"}, status_code=400)
+
 @app.get("/infer_miku")
 def get_default_inference_endpoint():
-    return {"message": "Use POST method to submit input data"}
-
+    return {"message": "Use POST method to submit input data"}