tokenize from the main session
- app.R            +21 -10
- model-session.R  +13 -12
app.R CHANGED

@@ -4,12 +4,13 @@ library(minhub)
 library(magrittr)
 source("model-session.R")
 
+repo <- "EleutherAI/pythia-70m"
 repo <- "stabilityai/stablelm-tuned-alpha-3b"
 repo <- Sys.getenv("MODEL_REPO", unset = repo)
 sess <- model_session$new()
 
 max_n_tokens <- 100
-system_prompt
+system_prompt <- "<|SYSTEM|># StableLM Tuned (Alpha version)
 - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
 - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
 - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
@@ -34,7 +35,7 @@ ui <- page_fillable(
 )
 
 server <- function(input, output, session) {
-
+  idxs <- reactiveVal()
   n_tokens <- reactiveVal(value = 0)
   msg_id <- reactiveVal(value = 0)
 
@@ -46,12 +47,21 @@ server <- function(input, output, session) {
     updateActionButton(inputId = "send", label = "Waiting for model...")
     insert_message(msg_id, as.character(glue::glue("🤗: {input$prompt}")))
 
+    if (is.null(idxs())) {
+      current_idxs <- sess$tok$encode(system_prompt)$ids
+    } else {
+      current_idxs <- idxs()
+    }
+
+    new_idxs <- paste0("<|USER|>", input$prompt, "<|ASSISTANT|>")
+    new_idxs <- sess$tok$encode(new_idxs)$ids
+
     # we modify the prompt to trigger the 'next_token' reactive
-
+    idxs(c(current_idxs, new_idxs))
   })
 
-  next_token <- eventReactive(
-
+  next_token <- eventReactive(idxs(), ignoreInit = TRUE, {
+    idxs() %>%
       sess$generate() %>%
       promises::then(
         onFulfilled = function(x) {x},
@@ -65,17 +75,18 @@ server <- function(input, output, session) {
 
   observeEvent(next_token(), {
     tok <- next_token()
-
     n_tokens(n_tokens() + 1)
    tok %>% promises::then(function(tok) {
+      tok_dec <- sess$tok$decode(tok)
       if (n_tokens() == 1) {
-        insert_message(msg_id, paste0("🤖: ",
+        insert_message(msg_id, paste0("🤖: ", tok_dec), append = FALSE)
       } else {
-        insert_message(msg_id,
+        insert_message(msg_id, tok_dec, append = TRUE)
       }
 
-      if (tok
-
+      if ((!tok %in% c(50278L, 50279L, 50277L, 1L, 0L)) &&
+          n_tokens() < max_n_tokens) {
+        idxs(c(idxs(), tok))
       } else {
         shinyjs::enable("send")
         updateActionButton(inputId = "send", label = "Send")
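What the change buys: the tokenizer now lives in the main R session, so the Shiny process passes integer token ids to the worker instead of raw strings, and can decode each sampled id as it streams back. A minimal sketch of that round trip, assuming only the tok package and a repo that ships a tokenizer.json:

library(tok)

# encode: chat-formatted text -> integer token ids, done in the main session
tokenizer <- tok::tokenizer$from_pretrained("EleutherAI/pythia-70m")
ids <- tokenizer$encode("<|USER|>Hello!<|ASSISTANT|>")$ids

# ids is a plain integer vector: cheap to ship to the worker and to grow
# one sampled token at a time, as app.R now does with idxs()
str(ids)

# decode: ids -> text, used to render each streamed token in the chat UI
tokenizer$decode(ids)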
model-session.R CHANGED

@@ -13,37 +13,38 @@ model_session <- R6::R6Class(
      cat("Model is already loaded.", "\n")
      return(self$task_q$push(function() "done"))
    }
+    # the tokenizer doesn't need to live in the remote session.
+    self$tok <- tok::tokenizer$from_pretrained(repo)
    self$task_q <- callq::task_q$new(num_workers = 1)
    self$task_q$push(args = list(repo = repo), function(repo) {
      library(torch)
      library(zeallot)
      library(minhub)
-
+      device <- if (cuda_is_available()) "cuda" else "cpu"
      model <<- minhub::gptneox_from_pretrained(repo)
      model$eval()
-
-
-
-
+      if (device == "cuda") {
+        model$to(dtype=torch_half())
+        model$to(device=device)
+      } else {
      model$to(dtype = torch_float())
-
-      tok <<- tok::tokenizer$from_pretrained(repo)
+      }
      "done"
    })
  },
-  generate = function(
+  generate = function(idx) {
    if (is.null(self$task_q)) {
      cat("Model is not loaded, error.", "\n")
      return(self$task_q$push(function() stop("Model is not loaded")))
    }
    args <- list(
-
+      idx = idx,
      temperature = self$temperature,
      top_k = self$top_k
    )
-    self$task_q$push(args = args, function(
+    self$task_q$push(args = args, function(idx, temperature, top_k) {
      device <- if (cuda_is_available()) "cuda" else "cpu"
-
+      idx <- torch_tensor(idx, device=device)$view(c(1, -1))
      with_no_grad({
        logits <- model(idx + 1L)$to(dtype="float", device="cpu")
      })
@@ -52,7 +53,7 @@ model_session <- R6::R6Class(
      logits <- torch_full_like(logits, -1e7)$scatter_(-1, ind, prob)
      logits <- nnf_softmax(logits, dim = -1)
      id_next <- torch::torch_multinomial(logits, num_samples = 1)$cpu() - 1L
-
+      as.integer(id_next)
    })
  }
)
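Taken together, model_session now splits cleanly: self$tok lives in the main session while the torch model stays inside the callq worker, and generate() only ever exchanges integer vectors with it. A hypothetical end-to-end sketch (load_model is an assumed name for the loading method, which sits above the first hunk and is not shown in this diff):

library(magrittr)
source("model-session.R")

sess <- model_session$new()

# boots a one-worker callq task queue and loads the weights inside it;
# like generate(), it returns a promise (method name assumed, see above)
sess$load_model("EleutherAI/pythia-70m")

# tokenize in the main session, exactly as app.R does on each send
idx <- sess$tok$encode("<|USER|>Hello!<|ASSISTANT|>")$ids

# one forward pass in the worker; resolves to a single sampled token id
# that the app decodes and appends back onto idxs()
sess$generate(idx) %>%
  promises::then(function(tok) cat("sampled id:", tok, "\n"))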