dfalbel committed
Commit 13c1f55
• 1 Parent(s): 173d645

tokenize from the main session

Files changed (2):
  1. app.R +21 -10
  2. model-session.R +13 -12
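In short: tokenization moves from the callq worker into the main R session, so the app and the worker only exchange integer token ids. A minimal sketch of the round trip, using the same tok calls the diff introduces (the repo shown is the example model from app.R):

library(tok)
tk  <- tokenizer$from_pretrained("EleutherAI/pythia-70m")
ids <- tk$encode("<|USER|>Hello!<|ASSISTANT|>")$ids  # text -> integer token ids
txt <- tk$decode(ids)                                # token ids -> text again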
app.R CHANGED
@@ -4,12 +4,13 @@ library(minhub)
 library(magrittr)
 source("model-session.R")
 
+repo <- "EleutherAI/pythia-70m"
 repo <- "stabilityai/stablelm-tuned-alpha-3b"
 repo <- Sys.getenv("MODEL_REPO", unset = repo)
 sess <- model_session$new()
 
 max_n_tokens <- 100
-system_prompt = "<|SYSTEM|># StableLM Tuned (Alpha version)
+system_prompt <- "<|SYSTEM|># StableLM Tuned (Alpha version)
 - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
 - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
 - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
@@ -34,7 +35,7 @@ ui <- page_fillable(
 )
 
 server <- function(input, output, session) {
-  prompt <- reactiveVal(value = system_prompt)
+  idxs <- reactiveVal()
   n_tokens <- reactiveVal(value = 0)
   msg_id <- reactiveVal(value = 0)
 
@@ -46,12 +47,21 @@ server <- function(input, output, session) {
     updateActionButton(inputId = "send", label = "Waiting for model...")
     insert_message(msg_id, as.character(glue::glue("πŸ€—: {input$prompt}")))
 
+    if (is.null(idxs())) {
+      current_idxs <- sess$tok$encode(system_prompt)$ids
+    } else {
+      current_idxs <- idxs()
+    }
+
+    new_idxs <- paste0("<|USER|>", input$prompt, "<|ASSISTANT|>")
+    new_idxs <- sess$tok$encode(new_idxs)$ids
+
     # we modify the prompt to trigger the 'next_token' reactive
-    prompt(paste0(prompt(), "<|USER|>", input$prompt, "<|ASSISTANT|>"))
+    idxs(c(current_idxs, new_idxs))
   })
 
-  next_token <- eventReactive(prompt(), ignoreInit = TRUE, {
-    prompt() %>%
+  next_token <- eventReactive(idxs(), ignoreInit = TRUE, {
+    idxs() %>%
       sess$generate() %>%
       promises::then(
        onFulfilled = function(x) {x},
@@ -65,17 +75,18 @@ server <- function(input, output, session) {
 
   observeEvent(next_token(), {
     tok <- next_token()
-
     n_tokens(n_tokens() + 1)
     tok %>% promises::then(function(tok) {
+      tok_dec <- sess$tok$decode(tok)
      if (n_tokens() == 1) {
-        insert_message(msg_id, paste0("πŸ€–: ", tok), append = FALSE)
+        insert_message(msg_id, paste0("πŸ€–: ", tok_dec), append = FALSE)
      } else {
-        insert_message(msg_id, tok, append = TRUE)
+        insert_message(msg_id, tok_dec, append = TRUE)
      }
 
-      if (tok != "" && n_tokens() < max_n_tokens) {
-        prompt(paste0(prompt(), tok))
+      if ((!tok %in% c(50278L, 50279L, 50277L, 1L, 0L)) &&
+          n_tokens() < max_n_tokens) {
+        idxs(c(idxs(), tok))
      } else {
        shinyjs::enable("send")
        updateActionButton(inputId = "send", label = "Send")
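The reactives above now form a feedback loop over token ids: submitting a prompt sets idxs(), next_token() asks the worker for one more id, and the observer appends that id back into idxs() until a stop id or max_n_tokens is reached (the hard-coded ids replace the old tok != "" check, since generate() now returns an integer rather than decoded text). Outside Shiny, the same loop could be written as a recursive promise chain; a rough sketch, assuming sess is an already-loaded model_session:

stop_ids <- c(50278L, 50279L, 50277L, 1L, 0L)  # ids the app treats as end-of-turn
step <- function(ids, n = 0) {
  sess$generate(ids) %>% promises::then(function(id) {
    cat(sess$tok$decode(id))
    if (!id %in% stop_ids && n < max_n_tokens) step(c(ids, id), n + 1)
  })
}
step(sess$tok$encode(system_prompt)$ids)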
model-session.R CHANGED
@@ -13,37 +13,38 @@ model_session <- R6::R6Class(
         cat("Model is already loaded.", "\n")
         return(self$task_q$push(function() "done"))
       }
+      # the tokenizer doesn't need to live in the remote session.
+      self$tok <- tok::tokenizer$from_pretrained(repo)
       self$task_q <- callq::task_q$new(num_workers = 1)
       self$task_q$push(args = list(repo = repo), function(repo) {
         library(torch)
         library(zeallot)
         library(minhub)
-        device <- if (cuda_is_available()) "cuda" else "cpu"
+        device <- if (cuda_is_available()) "cuda" else "cpu"
         model <<- minhub::gptneox_from_pretrained(repo)
         model$eval()
-        if (device == "cuda") {
-          model$to(device=device)
-          #model$to(dtype=torch_float())
-        } else {
+        if (device == "cuda") {
+          model$to(dtype=torch_half())
+          model$to(device=device)
+        } else {
          model$to(dtype = torch_float())
-        }
-        tok <<- tok::tokenizer$from_pretrained(repo)
+        }
         "done"
       })
     },
-    generate = function(prompt) {
+    generate = function(idx) {
       if (is.null(self$task_q)) {
         cat("Model is not loaded, error.", "\n")
         return(self$task_q$push(function() stop("Model is not loaded")))
       }
       args <- list(
-        prompt = prompt,
+        idx = idx,
         temperature = self$temperature,
         top_k = self$top_k
       )
-      self$task_q$push(args = args, function(prompt, temperature, top_k) {
+      self$task_q$push(args = args, function(idx, temperature, top_k) {
         device <- if (cuda_is_available()) "cuda" else "cpu"
-        idx <- torch_tensor(tok$encode(prompt)$ids, device=device)$view(c(1, -1))
+        idx <- torch_tensor(idx, device=device)$view(c(1, -1))
         with_no_grad({
           logits <- model(idx + 1L)$to(dtype="float", device="cpu")
         })
@@ -52,7 +53,7 @@
         logits <- torch_full_like(logits, -1e7)$scatter_(-1, ind, prob)
         logits <- nnf_softmax(logits, dim = -1)
         id_next <- torch::torch_multinomial(logits, num_samples = 1)$cpu() - 1L
-        tok$decode(as.integer(id_next))
+        as.integer(id_next)
       })
     }
   )
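Two details worth noting in generate(): on CUDA the weights are now cast to half precision before being moved to the device, and because R torch indexes from 1 while token ids are 0-based, the input is shifted with idx + 1L and the sampled index shifted back with - 1L. A self-contained sketch of the same top-k sampling step on fake logits (the vocabulary size here is illustrative):

library(torch)
library(zeallot)

logits <- torch_randn(1, 50432)          # stand-in for the model's last-position logits
c(prob, ind) %<-% logits$topk(10)        # keep only the 10 largest logits, as the app does
logits <- torch_full_like(logits, -1e7)$scatter_(-1, ind, prob)  # mask out the rest
probs <- nnf_softmax(logits, dim = -1)
id_next <- torch_multinomial(probs, num_samples = 1) - 1L        # 1-based draw -> 0-based id
as.integer(id_next)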