Empereur-Pirate committed on
Commit 79e06e3
1 Parent(s): 44a140e

Update main.py

Files changed (1):
  1. main.py +39 -67

main.py CHANGED
@@ -7,89 +7,61 @@ from typing import Optional, Any
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 
-# Helper function to read raw request bodies
-async def parse_raw(request: Request):
-    return await request.body()
-
 app = FastAPI()
 
-# Serve the static files
-app.mount("/static", StaticFiles(directory="static"), name="static")
-
-pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
-
-def t5(input: str) -> dict[str, str]:
-    output = pipe_flan(input)
-    return {"output": output[0].get("generated_text", "")}
-
-class ParseRaw(BaseModel):
-    raw: bytes
-
-@app.post("/infer_t5")
-async def infer_endpoint(data: ParseRaw = Depends(parse_raw)):
-    """Receive input and generate text."""
+# Initialize device
+def initialize_device():
+    global device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+initialize_device()
+
+# Helper function to read raw request bodies
+async def parse_raw(request: Request):
+    return await request.body()
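
Note: parse_raw returns the raw bytes of the request body, so whatever depends on it receives bytes, not a Pydantic model. A minimal sketch of the same dependency with the return type spelled out (the /echo route and handler name are illustrative, not part of the commit):

    # Sketch: typed raw-body dependency (illustrative, not from the commit).
    from fastapi import Depends, FastAPI, Request

    app = FastAPI()

    async def parse_raw(request: Request) -> bytes:
        # request.body() reads the entire request payload as bytes
        return await request.body()

    @app.post("/echo")
    async def echo(body: bytes = Depends(parse_raw)):
        # Decode defensively; invalid UTF-8 is replaced instead of raising
        return {"received": body.decode("utf-8", errors="replace")}
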
+# Initialize the model and tokenizer with the corrected pre-trained weights
+def init_corrected_model():
+    global model_config, model, tokenizer
+
     try:
-        input_text = data.raw.decode("utf-8")
-
-        # Validate that the input is a string
-        assert isinstance(input_text, str), "Input must be a string."
-
-        if input_text is None:
-            return {"error": "No input text detected."}
-
-        result = t5(input_text)
-        return result
-
-    except AssertionError as e:
-        return JSONResponse({"error": f"Invalid Input Format: {e}"}, status_code=400)
-
-@app.get("/infer_t5")
-def get_default_inference_endpoint():
-    return {"message": "Use POST method to submit input data"}
-
-# Initialize device
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-try:
-    # Initiate the model and tokenizer with the corrected pre-trained weights
-    model_config = AutoConfig.from_pretrained("152334H/miqu-1-70b-sf", trust_remote_code=True)
-    model = AutoModelForCausalLM.from_pretrained("152334H/miqu-1-70b-sf", config=model_config).to(device)
-    tokenizer = AutoTokenizer.from_pretrained("152334H/miqu-1-70b-sf")
-except Exception as e:
-    print("[WARNING]: Failed to load model and tokenizer conventionally.")
-    print(f"Exception: {e}")
-
-    # Configure a fallback mechanism similar to the original implementation
-    model_config = AutoConfig.from_pretrained("152334H/miqu-1-70b-sf", trust_remote_code=True)
-
-    # Load the model using the fallback configuration
-    model = AutoModelForCausalLM.from_pretrained("152334H/miqu-1-70b-sf", config=model_config).to(device)
-    tokenizer = AutoTokenizer.from_pretrained("152334H/miqu-1-70b-sf")
-
-def miuk_answer(query: str) -> str:
+        model_config = AutoConfig.from_pretrained("152334H/miqu-1-70b-sf", trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained("152334H/miqu-1-70b-sf", config=model_config).to(device)
+        tokenizer = AutoTokenizer.from_pretrained("152334H/miqu-1-70b-sf")
+    except Exception as e:
+        print("[WARNING]: Failed to load model and tokenizer conventionally.")
+        print(f"Exception: {e}")
+
+        model_config = AutoConfig.from_pretrained("152334H/miqu-1-70b-sf", trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained("152334H/miqu-1-70b-sf", config=model_config).to(device)
+        tokenizer = AutoTokenizer.from_pretrained("152334H/miqu-1-70b-sf")
+
+init_corrected_model()
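
Note: the except branch repeats the exact from_pretrained calls from the try branch, so any failure will simply recur. A fallback only helps if it changes something about the load. A hedged sketch of one alternative, assuming the accelerate package is available for device_map="auto" (whether this rescues a 70B checkpoint still depends on available memory):

    # Sketch: a fallback that actually differs from the first attempt.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL_ID = "152334H/miqu-1-70b-sf"

    def load_model_with_fallback():
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        try:
            # First attempt: default full-precision load
            model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
        except Exception as exc:
            print(f"[WARNING]: default load failed ({exc}); retrying in float16.")
            # Second attempt: half precision, sharded across available devices
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        return model, tokenizer
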
+# Utility function to generate answers from the model
+def miuk_answer(query: str) -> dict[str, str]:
     query_tokens = tokenizer.encode(query, return_tensors="pt")
     query_tokens = query_tokens.to(device)
     answer = model.generate(query_tokens, max_length=128, temperature=1, pad_token_id=tokenizer.pad_token_id)
-    return tokenizer.decode(answer[:, 0]).replace(" ", "")
+    return {"output": tokenizer.decode(answer[:, 0])}
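
Note: answer has shape (batch, sequence_length), so answer[:, 0] selects only the first token of each sequence and the decoded "output" is a single token. Decoding answer[0] returns the full first sequence. Also, temperature only takes effect when do_sample=True; without it, generate runs greedy decoding and ignores the setting. A sketch assuming the model, tokenizer, and device defined above:

    # Sketch: decode the whole first generated sequence, not its first token.
    def miuk_answer_fixed(query: str) -> dict[str, str]:
        query_tokens = tokenizer.encode(query, return_tensors="pt").to(device)
        answer = model.generate(
            query_tokens,
            max_length=128,
            do_sample=True,   # temperature is only used when sampling
            temperature=1.0,
            pad_token_id=tokenizer.pad_token_id,
        )
        # answer[0] is the full first sequence of token ids
        return {"output": tokenizer.decode(answer[0], skip_special_tokens=True)}
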
-@app.post("/infer_miku")
-async def infer_endpoint(data: ParseRaw = Depends(parse_raw)):
-    """Receive input and generate text."""
-    try:
-        input_text = data.raw.decode("utf-8")
-
-        # Validate that the input is a string
-        assert isinstance(input_text, str), "Input must be a string."
-
-        if input_text is None:
-            return {"error": "No input text detected."}
-
-        result = {"output": miuk_answer(input_text)}
-        return result
-
-    except AssertionError as e:
-        return JSONResponse({"error": f"Invalid Input Format: {e}"}, status_code=400)
-
-@app.get("/infer_miku")
-def get_default_inference_endpoint():
-    return {"message": "Use POST method to submit input data"}
+# Endpoint handler to receive incoming queries and pass them to the utility function for processing
+@app.post("/infer_miku")
+async def infer_endpoint(data: BaseModel = Depends(parse_raw)):
+    input_text = data.raw.decode("utf-8")
+
+    if input_text is None or len(input_text) == 0:
+        return JSONResponse({"error": "Empty input received."}, status_code=400)
+
+    result = miuk_answer(input_text)
+    return result
+
+@app.get("/infer_miku")
+def get_default_inference_endpoint():
+    return {"message": "Use POST method to submit input data"}
 
+
+# Mount static files
+app.mount("/static", StaticFiles(directory="static"), name="static")
+
+# Initialization done
+print("Initialization Complete.")