Bunpheng committed
Commit adf0c8a
1 Parent(s): c73b3ce

Update download_model.py

Files changed (1)
  1. download_model.py +27 -10
download_model.py CHANGED
@@ -1,16 +1,33 @@
-# download_model.py
+from fastapi import FastAPI
+from transformers import pipeline
 from huggingface_hub import snapshot_download
 from pathlib import Path
+import os
 
-def download_model():
-    mistral_models_path = Path.home().joinpath('mistral_models', '7B-v0.3')
-    mistral_models_path.mkdir(parents=True, exist_ok=True)
-
+app = FastAPI()
+
+# Define the path where the model will be downloaded
+mistral_models_path = Path.home().joinpath('mistral_models', '7B-v0.3')
+mistral_models_path.mkdir(parents=True, exist_ok=True)
+
+# Download the model if not already present
+if not (mistral_models_path / "params.json").exists():
     snapshot_download(
-        repo_id="mistralai/Mistral-7B-v0.3",
-        allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"],
-        local_dir=mistral_models_path
+        repo_id="mistralai/Mistral-7B-v0.3",
+        allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"],
+        local_dir=mistral_models_path,
+        token=os.getenv('HUGGINGFACE_HUB_TOKEN')  # Use the environment variable for authentication
     )
 
-if __name__ == "__main__":
-    download_model()
+# Load the model
+pipe_mistral = pipeline("text2text-generation", model=str(mistral_models_path))
+
+@app.get("/mistral")
+def mistral_endpoint(input: str):
+    output = pipe_mistral(input)
+    return {"output": output[0]["generated_text"]}
+
+@app.get("/")
+def greet_json():
+    return {"Hello": "World!"}
+
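
With this change, download_model.py no longer exposes a standalone download_model() entry point: it downloads the weights at import time (authenticating via HUGGINGFACE_HUB_TOKEN, as the new token argument shows) and serves them through FastAPI. The sketch below shows one way the new routes could be exercised once the app is running; the uvicorn command, the port, and the use of the requests library are assumptions for illustration and are not part of this commit.

# Hypothetical smoke test for the routes added in this commit (not part of the commit itself).
# Assumes the app was started first, e.g.:  uvicorn download_model:app --port 8000
import requests

print(requests.get("http://localhost:8000/").json())
# -> {"Hello": "World!"}

print(requests.get("http://localhost:8000/mistral", params={"input": "Say hello"}).json())
# -> {"output": "..."}  (the pipeline's generated text)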