# starchat-ggml / main.py
# Author: matthoffner. Last commit: 3366fc4 ("Update main.py"). Size: 4.64 kB.
from typing import List, Union

import fastapi
import markdown
import uvicorn
from ctransformers import AutoModelForCausalLM
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
# Load the q4_0-quantized StarChat Alpha GGML weights from the
# "NeoDim/starchat-alpha-GGML" repo once, at import time; every request
# handler below reuses this single module-level model object.
llm = AutoModelForCausalLM.from_pretrained(
    "NeoDim/starchat-alpha-GGML",
    model_type="starcoder",
    model_file="starchat-alpha-ggml-q4_0.bin",
)
app = fastapi.FastAPI(title="Starchat Alpha")

# Fully permissive CORS: browsers on any origin may call any method with any
# header, and credentialed requests are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site send credentialed requests; the app uses no credentials today.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.get("/")
async def index():
    """Serve ``README.md`` (resolved against the working directory) as HTML."""
    with open("README.md", "r", encoding="utf-8") as readme:
        rendered = markdown.markdown(readme.read())
    return HTMLResponse(status_code=200, content=rendered)
@app.get("/demo")
async def demo():
    """Serve a small browser page that streams this app's ``/stream`` output.

    The page opens an ``EventSource`` on the same host's ``/stream`` endpoint
    and appends each received event to ``#content`` as HTML, with a few
    ad-hoc tweaks for code fences and the ``<|assistant|>``/``<|end|>`` markers.
    """
    html_content = """
<!DOCTYPE html>
<html>
<head>
<script src="https://cdnjs.cloudflare.com/ajax/libs/showdown/1.9.1/showdown.min.js"></script>
</head>
<body>
<style>
body {
  font-family: -apple-system,BlinkMacSystemFont,"Segoe UI",Helvetica,Arial,sans-serif,"Apple Color Emoji","Segoe UI Emoji","Segoe UI Symbol";
}
code {
  font-family: "SFMono-Regular",Consolas,"Liberation Mono",Menlo,Courier,monospace !important;
  display: inline-block;
  background-color: lightgray;
}
h1, h2, h3, h4, h5, h6 {
  font-family: Roboto,-apple-system,BlinkMacSystemFont,"Helvetica Neue","Segoe UI","Oxygen","Ubuntu","Cantarell","Open Sans",sans-serif;
}
#content {
  box-sizing: border-box;
  min-width: 200px;
  max-width: 980px;
  margin: 0 auto;
  padding: 45px;
  font-size: 16px;
}
@media (max-width: 767px) {
  #content {
    padding: 15px;
  }
}
</style>
<script type="module" src="https://cdn.skypack.dev/@vanillawc/wc-markdown"></script>
<wc-markdown id="content" highlight><h1>starchat-alpha-q4.0</h1></wc-markdown>
<script>
var converter = new showdown.Converter();
// Relative URL: stream from whichever host served this page.
var source = new EventSource("/stream");
let eventCache;
// EventSource reconnects automatically when the server closes the stream,
// which would restart generation and append a duplicate answer. Stop instead.
source.onerror = function() {
  source.close();
};
source.onmessage = function(event) {
  let eventData = event.data;
  console.log(eventData);
  if (eventData.includes("```")) {
    eventCache = true;
    return;
  }
  if (eventCache && !eventData.includes("```")) {
    const backticks = "```";
    eventData = `${backticks}${eventData}<br /><code>`;
    eventCache = false;
  }
  if (eventData === ":") {
    eventData = `${eventData}<br />`;
  }
  if (eventData === "<|assistant|>") {
    eventData = `<br />${eventData}`;
  }
  if (eventData === "<|end|>") {
    eventData = "<br />";
  }
  // NOTE: model output is inserted as raw HTML (not escaped).
  document.getElementById("content").innerHTML = document.getElementById("content").innerHTML + eventData;
};
</script>
</body>
</html>
"""
    return HTMLResponse(content=html_content, status_code=200)
@app.get("/stream")
async def chat(prompt="<|user|> Write an express server with server sent events. <|assistant|>"):
    """Stream the model's continuation of *prompt* as server-sent events.

    The first event echoes *prompt* verbatim. Each generated token follows,
    detokenized into its own event, and a final empty event is sent after
    generation ends.
    """
    prompt_tokens = llm.tokenize(prompt)

    async def event_stream(token_ids, model):
        yield prompt
        for token_id in model.generate(token_ids):
            yield model.detokenize(token_id)
        yield ""

    return EventSourceResponse(event_stream(prompt_tokens, llm))
class ChatMessage(BaseModel):
    """One chat turn in the request's ``messages`` list."""

    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    """JSON body for ``POST /v1/chat/completions``.

    ``messages`` is either a raw prompt string, which is tokenized as-is, or a
    list of role/content turns rendered into ``<|role|>`` chat markup.
    """

    messages: Union[str, List[ChatMessage]]


def _messages_to_prompt(messages):
    """Render request *messages* into the single prompt string fed to the model."""
    if isinstance(messages, str):
        return messages
    # NOTE(review): assumes StarChat's "<|role|>\n...<|end|>\n" turn format,
    # matching the <|user|>/<|assistant|>/<|end|> markers used elsewhere here.
    turns = "".join(
        f"<|{message.role}|>\n{message.content}<|end|>\n" for message in messages
    )
    # Leave the assistant turn open so the model writes the reply.
    return f"{turns}<|assistant|>"


@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest, response_mode=None):
    """Stream a completion for the posted chat messages as server-sent events.

    ``request`` is parsed from the JSON body. Before this annotation, FastAPI
    treated the bare parameter as a string query parameter, so
    ``request.messages`` always failed. Each generated token is sent
    detokenized as its own event, followed by a final empty event.
    ``response_mode`` is accepted for compatibility but unused.

    Raises:
        fastapi.HTTPException: 422 when ``messages`` is empty.

    NOTE: this rebinds the module-level name ``chat``. The GET ``/stream``
    handler above it stays registered with FastAPI.
    """
    if not request.messages:
        raise fastapi.HTTPException(status_code=422, detail="messages must not be empty")
    tokens = llm.tokenize(_messages_to_prompt(request.messages))

    async def server_sent_events(chat_chunks, llm):
        for token in llm.generate(chat_chunks):
            yield llm.detokenize(token)
        yield ""

    return EventSourceResponse(server_sent_events(tokens, llm))
if __name__ == "__main__":
    # When run as a script, serve the app on every interface at port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")