jbilcke-hf HF staff commited on
Commit
b0f34ee
·
1 Parent(s): bd74bd1

added the model picker

Browse files
.env CHANGED
@@ -37,4 +37,4 @@ COMMUNITY_API_URL="https://jbilcke-hf-community.hf.space"
37
  COMMUNITY_API_TOKEN=""
38
 
39
  # must be unique per app
40
- COMMUNITY_APP_ID=""
 
37
  COMMUNITY_API_TOKEN=""
38
 
39
  # must be unique per app
40
+ COMMUNITY_API_ID=""
src/app/interface/generate/index.tsx CHANGED
@@ -1,6 +1,6 @@
1
  "use client"
2
 
3
- import { useState, useTransition } from "react"
4
  import { useSpring, animated } from "@react-spring/web"
5
  import { usePathname, useRouter, useSearchParams } from "next/navigation"
6
 
@@ -11,6 +11,8 @@ import { generateAnimation } from "@/app/server/actions/animation"
11
  import { postToCommunity } from "@/app/server/actions/community"
12
  import { useCountdown } from "@/lib/useCountdown"
13
  import { Countdown } from "../countdown"
 
 
14
 
15
  export function Generate() {
16
  const router = useRouter()
@@ -18,15 +20,26 @@ export function Generate() {
18
  const searchParams = useSearchParams()
19
  const [_isPending, startTransition] = useTransition()
20
 
 
 
21
  const [isLocked, setLocked] = useState(false)
22
  const [promptDraft, setPromptDraft] = useState("")
23
  const [assetUrl, setAssetUrl] = useState("")
24
  const [isOverSubmitButton, setOverSubmitButton] = useState(false)
25
 
 
 
 
26
  const [runs, setRuns] = useState(0)
 
 
 
 
 
27
  const { progressPercent, remainingTimeInSec } = useCountdown({
 
28
  timerId: runs, // everytime we change this, the timer will reset
29
- durationInSec: 40,
30
  onEnd: () => {}
31
  })
32
 
@@ -54,11 +67,27 @@ export function Generate() {
54
  console.log("handleSubmit:", { isLocked, promptDraft })
55
  if (isLocked) { return }
56
  if (!promptDraft) { return }
 
 
57
  setRuns(runs + 1)
58
  setLocked(true)
 
 
 
 
 
 
59
  startTransition(async () => {
60
- const huggingFaceLora = "KappaNeuro/studio-ghibli-style"
61
- const triggerWord = "Studio Ghibli Style"
 
 
 
 
 
 
 
 
62
  try {
63
  console.log("starting transition, calling generateAnimation")
64
  const newAssetUrl = await generateAnimation({
@@ -66,33 +95,11 @@ export function Generate() {
66
  negativePrompt: "",
67
  huggingFaceLora,
68
  triggerWord,
69
- // huggingFaceLora: "veryVANYA/ps1-graphics-sdxl-v2", //
70
- // huggingFaceLora: "ostris/crayon_style_lora_sdxl", // "https://huggingface.co/ostris/crayon_style_lora_sdxl/resolve/main/crayons_v1_sdxl.safetensors",
71
- // replicateLora: "https://replicate.com/jbilcke/sdxl-panorama",
72
-
73
- // ---- replicate models -----
74
- // use: "in the style of TOK" in the prompt!
75
- // or this? "in the style of <s0><s1>"
76
- // I don't see a lot of diff
77
- //
78
- // Zelda BOTW
79
- // replicateLora: "https://pbxt.replicate.delivery/8UkalcGbGnrNHxGeqeCrhKcPbrRDlx4vLToRRlUWqzpnfieFB/trained_model.tar",
80
-
81
- // Zelda64
82
- // replicateLora: "https://pbxt.replicate.delivery/HPZlvCwDWtb5KpefUUcofwvZwTbrZAH9oLvzrn24hqUcQBfFB/trained_model.tar",
83
-
84
- // panorama lora
85
- // replicateLora: "https://pbxt.replicate.delivery/nuXez5QNfEmhPk1TLGtl8Q0TwyucZbzTsfUe1ibUfNV0JrMMC/trained_model.tar",
86
-
87
- // foundation
88
- // replicateLora: "https://pbxt.replicate.delivery/VHU109Irgh6EPJrZ7aVScvadYDqXhlL3izfEAjfhs8Cvz0hRA/trained_model.tar",
89
-
90
- size: "672x384", // "1024x512", // "512x512" // "320x768"
91
-
92
  nbFrames: 8, // if duration is 1000ms then it means 8 FPS
93
  duration: 1000, // in ms
94
  steps: 25,
95
- })
96
  setAssetUrl(newAssetUrl)
97
 
98
  try {
@@ -111,7 +118,7 @@ export function Generate() {
111
  const search = current.toString()
112
  router.push(`${pathname}${search ? `?${search}` : ""}`)
113
  } catch (err) {
114
- console.error(`not a blocked, but we failed to post to the community (reason: ${err})`)
115
  }
116
  } catch (err) {
117
  console.error(err)
@@ -121,6 +128,18 @@ export function Generate() {
121
  })
122
  }
123
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  return (
125
  <div className={cn(
126
  `fixed inset-0 w-screen h-screen`,
@@ -137,17 +156,22 @@ export function Generate() {
137
  `w-full md:max-w-4xl lg:max-w-5xl xl:max-w-6xl max-h-[80vh]`,
138
  `space-y-3 md:space-y-0 md:space-x-6`,
139
  `transition-all duration-300 ease-in-out`,
140
-
141
  )}>
142
- <div className={cn(
 
 
143
  `flex flex-col`,
144
  `flex-grow rounded-2xl md:rounded-3xl`,
145
  `backdrop-blur-lg bg-white/40`,
146
  `border-2 border-white/10`,
147
  `items-center`,
148
  `space-y-6 md:space-y-8 lg:space-y-12 xl:space-y-16`,
149
- `px-3 py-6 md:px-6 md:py-12 xl:px-8 xl:py-16`
150
- )}>
 
 
 
 
151
 
152
  {assetUrl ? <div
153
  className={cn(
@@ -226,40 +250,71 @@ export function Generate() {
226
  </div>
227
  </div>
228
  <div className="flex flex-row w-52">
229
- <animated.button
230
- style={{
231
- textShadow: "0px 0px 1px #000000ab",
232
- ...submitButtonBouncer
233
- }}
234
- onMouseEnter={() => setOverSubmitButton(true)}
235
- onMouseLeave={() => setOverSubmitButton(false)}
236
- className={cn(
237
- `px-6 py-3`,
238
- `rounded-full`,
239
- `transition-all duration-300 ease-in-out`,
240
- isLocked
241
- ? `bg-orange-500/20 border-orange-800/10`
242
- : `bg-sky-500/80 hover:bg-sky-400/100 border-sky-800/20`,
243
- `text-center`,
244
- `w-full`,
245
- `text-2xl text-sky-50`,
246
- `border`,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  headingFont.className,
248
- // `transition-all duration-300`,
249
- // `hover:animate-bounce`
250
- )}
251
- disabled={isLocked}
252
- onClick={handleSubmit}
253
- >
254
- {isLocked ? "Generating.." : "Generate"}
255
- </animated.button>
256
  </div>
257
- <div>
258
- Pick a model..
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  </div>
260
  </div>
261
 
262
- <div>
263
  <p>Community creations</p>
264
  <div>
265
  <div>A</div>
@@ -268,7 +323,8 @@ export function Generate() {
268
  <div>D</div>
269
  <div>E</div>
270
  </div>
271
- </div>
 
272
  </div>
273
  </div>
274
  </div>
 
1
  "use client"
2
 
3
+ import { useEffect, useRef, useState, useTransition } from "react"
4
  import { useSpring, animated } from "@react-spring/web"
5
  import { usePathname, useRouter, useSearchParams } from "next/navigation"
6
 
 
11
  import { postToCommunity } from "@/app/server/actions/community"
12
  import { useCountdown } from "@/lib/useCountdown"
13
  import { Countdown } from "../countdown"
14
+ import { getSDXLModels } from "@/app/server/actions/models"
15
+ import { SDXLModel } from "@/types"
16
 
17
  export function Generate() {
18
  const router = useRouter()
 
20
  const searchParams = useSearchParams()
21
  const [_isPending, startTransition] = useTransition()
22
 
23
+ const scrollRef = useRef<HTMLDivElement>(null)
24
+
25
  const [isLocked, setLocked] = useState(false)
26
  const [promptDraft, setPromptDraft] = useState("")
27
  const [assetUrl, setAssetUrl] = useState("")
28
  const [isOverSubmitButton, setOverSubmitButton] = useState(false)
29
 
30
+ const [models, setModels] = useState<SDXLModel[]>([])
31
+ const [selectedModel, setSelectedModel] = useState<SDXLModel>()
32
+
33
  const [runs, setRuns] = useState(0)
34
+ const runsRef = useRef(0)
35
+ const [showModels, setShowModels] = useState(true)
36
+ // useEffect(() => { runsRef.current = runs }, [runs])
37
+
38
+ console.log("runs:", runs)
39
  const { progressPercent, remainingTimeInSec } = useCountdown({
40
+ isActive: isLocked,
41
  timerId: runs, // everytime we change this, the timer will reset
42
+ durationInSec: 45,
43
  onEnd: () => {}
44
  })
45
 
 
67
  console.log("handleSubmit:", { isLocked, promptDraft })
68
  if (isLocked) { return }
69
  if (!promptDraft) { return }
70
+
71
+ setShowModels(false)
72
  setRuns(runs + 1)
73
  setLocked(true)
74
+
75
+ scrollRef.current?.scroll({
76
+ top: 0,
77
+ behavior: 'smooth'
78
+ })
79
+
80
  startTransition(async () => {
81
+ const huggingFaceLora = selectedModel ? selectedModel.repo : "KappaNeuro/studio-ghibli-style"
82
+ const triggerWord = selectedModel ? selectedModel.trigger_word : "Studio Ghibli Style"
83
+
84
+ // now you got a read/write object
85
+ const current = new URLSearchParams(Array.from(searchParams.entries()))
86
+ current.set("prompt", promptDraft)
87
+ current.set("model", huggingFaceLora)
88
+ const search = current.toString()
89
+ router.push(`${pathname}${search ? `?${search}` : ""}`)
90
+
91
  try {
92
  console.log("starting transition, calling generateAnimation")
93
  const newAssetUrl = await generateAnimation({
 
95
  negativePrompt: "",
96
  huggingFaceLora,
97
  triggerWord,
98
+ size: "608x416", // "1024x512", // "512x512" // "320x768"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  nbFrames: 8, // if duration is 1000ms then it means 8 FPS
100
  duration: 1000, // in ms
101
  steps: 25,
102
+ })
103
  setAssetUrl(newAssetUrl)
104
 
105
  try {
 
118
  const search = current.toString()
119
  router.push(`${pathname}${search ? `?${search}` : ""}`)
120
  } catch (err) {
121
+ console.error(`not a blocker, but we failed to post to the community (reason: ${err})`)
122
  }
123
  } catch (err) {
124
  console.error(err)
 
128
  })
129
  }
130
 
131
+ useEffect(() => {
132
+ startTransition(async () => {
133
+ const models = await getSDXLModels()
134
+ setModels(models)
135
+
136
+ let defaultModel = models.find(model => model.title.toLowerCase().includes("ghibli")) || models[0]
137
+ if (defaultModel) {
138
+ setSelectedModel(defaultModel)
139
+ }
140
+ })
141
+ }, [])
142
+
143
  return (
144
  <div className={cn(
145
  `fixed inset-0 w-screen h-screen`,
 
156
  `w-full md:max-w-4xl lg:max-w-5xl xl:max-w-6xl max-h-[80vh]`,
157
  `space-y-3 md:space-y-0 md:space-x-6`,
158
  `transition-all duration-300 ease-in-out`,
 
159
  )}>
160
+ <div
161
+ ref={scrollRef}
162
+ className={cn(
163
  `flex flex-col`,
164
  `flex-grow rounded-2xl md:rounded-3xl`,
165
  `backdrop-blur-lg bg-white/40`,
166
  `border-2 border-white/10`,
167
  `items-center`,
168
  `space-y-6 md:space-y-8 lg:space-y-12 xl:space-y-16`,
169
+ `px-3 py-6 md:px-6 md:py-12 xl:px-8 xl:py-16`,
170
+ `overflow-y-scroll`,
171
+ )}
172
+ style={{
173
+ boxShadow: "inset 0 2px 4px 0 rgb(0 0 0 / 0.05)" // TODO: convert to tailwind
174
+ }}>
175
 
176
  {assetUrl ? <div
177
  className={cn(
 
250
  </div>
251
  </div>
252
  <div className="flex flex-row w-52">
253
+ <animated.button
254
+ style={{
255
+ textShadow: "0px 0px 1px #000000ab",
256
+ ...submitButtonBouncer
257
+ }}
258
+ onMouseEnter={() => setOverSubmitButton(true)}
259
+ onMouseLeave={() => setOverSubmitButton(false)}
260
+ className={cn(
261
+ `px-6 py-3`,
262
+ `rounded-full`,
263
+ `transition-all duration-300 ease-in-out`,
264
+ isLocked
265
+ ? `bg-orange-500/20 border-orange-800/10`
266
+ : `bg-sky-500/80 hover:bg-sky-400/100 border-sky-800/20`,
267
+ `text-center`,
268
+ `w-full`,
269
+ `text-2xl text-sky-50`,
270
+ `border`,
271
+ headingFont.className,
272
+ // `transition-all duration-300`,
273
+ // `hover:animate-bounce`
274
+ )}
275
+ disabled={isLocked}
276
+ onClick={handleSubmit}
277
+ >
278
+ {isLocked ? `Please wait..` : "Generate"}
279
+ </animated.button>
280
+ </div>
281
+ </div>
282
+
283
+ <div className="flex flex-col">
284
+ <div className="flex flex-row">
285
+ <h3 className={cn(
286
  headingFont.className,
287
+ "text-2xl text-sky-600 mb-4"
288
+ )}>{models.length ? "Pick a style:" : "Loading styles.."}</h3>
 
 
 
 
 
 
289
  </div>
290
+ <div className="grid grid-cols-4 sm:grid-cols-6 md:grid-cols-10 lg:grid-cols-11 xl:grid-cols-12 gap-2">
291
+ {models.map(model =>
292
+ <div key={model.repo}
293
+ className={isLocked ? '' : `cursor-pointer`}
294
+ onClick={() => {
295
+ if (!isLocked) { setSelectedModel(model) }
296
+ }}>
297
+ <img
298
+ src={
299
+ model.image.startsWith("http")
300
+ ? model.image
301
+ : `https://multimodalart-loratheexplorer.hf.space/file=${model.image}`
302
+ }
303
+ className={cn(
304
+ `transition-all duration-150 ease-in-out`,
305
+ `w-20 h-20 object-cover rounded-lg overflow-hidden`,
306
+ `border-4 border-transparent`,
307
+ `hover:border-yellow-50 hover:scale-110`,
308
+ selectedModel?.repo === model.repo
309
+ ? `scale-110 border-4 border-yellow-300 hover:border-yellow-300`
310
+ : ``
311
+ )}
312
+ ></img>
313
+ </div>)}
314
  </div>
315
  </div>
316
 
317
+ {/*<div>
318
  <p>Community creations</p>
319
  <div>
320
  <div>A</div>
 
323
  <div>D</div>
324
  <div>E</div>
325
  </div>
326
+ </div>
327
+ */}
328
  </div>
329
  </div>
330
  </div>
src/app/server/actions/community.ts CHANGED
@@ -59,9 +59,11 @@ export async function postToCommunity({
59
 
60
  const postId = uuidv4()
61
 
62
- const post: Partial<Post> = { postId, appId, prompt, assetUrl }
63
 
64
- console.table(post)
 
 
65
 
66
  const res = await fetch(`${apiUrl}/posts/${appId}`, {
67
  method: "POST",
@@ -76,18 +78,14 @@ export async function postToCommunity({
76
  // next: { revalidate: 1 }
77
  })
78
 
79
- // console.log("res:", res)
80
- // The return value is *not* serialized
81
- // You can return Date, Map, Set, etc.
82
-
83
  // Recommendation: handle errors
84
- if (res.status !== 200) {
85
  // This will activate the closest `error.js` Error Boundary
86
  throw new Error('Failed to fetch data')
87
  }
88
 
89
  const response = (await res.json()) as CreatePostResponse
90
- // console.log("response:", response)
91
  return response.post
92
  } catch (err) {
93
  const error = `failed to post to community: ${err}`
 
59
 
60
  const postId = uuidv4()
61
 
62
+ const post: Partial<Post> = { postId, appId, prompt, model, assetUrl }
63
 
64
+ console.log(`target url is: ${
65
+ `${apiUrl}/posts/${appId}`
66
+ }`)
67
 
68
  const res = await fetch(`${apiUrl}/posts/${appId}`, {
69
  method: "POST",
 
78
  // next: { revalidate: 1 }
79
  })
80
 
 
 
 
 
81
  // Recommendation: handle errors
82
+ if (res.status !== 201) {
83
  // This will activate the closest `error.js` Error Boundary
84
  throw new Error('Failed to fetch data')
85
  }
86
 
87
  const response = (await res.json()) as CreatePostResponse
88
+ console.log("response:", response)
89
  return response.post
90
  } catch (err) {
91
  const error = `failed to post to community: ${err}`
src/app/server/actions/models.ts ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "use server"
2
+
3
+ import { SDXLModel } from "@/types"
4
+
5
+ const SDXL_MODEL_DATABASE_URL = "https://huggingface.co/spaces/multimodalart/LoraTheExplorer/raw/main/sdxl_loras.json"
6
+
7
+ export async function getSDXLModels(): Promise<SDXLModel[]> {
8
+ const res = await fetch(SDXL_MODEL_DATABASE_URL, {
9
+ method: "GET",
10
+ headers: {
11
+ "Content-Type": "application/json"
12
+ },
13
+ cache: "no-store",
14
+ // we can also use this (see https://vercel.com/blog/vercel-cache-api-nextjs-cache)
15
+ // next: { revalidate: 1 }
16
+ })
17
+
18
+ const content = await res.json() as SDXLModel[]
19
+
20
+ // we only return compatible models
21
+ return content.filter(model => model.is_compatible)
22
+ }
src/lib/useCountdown.ts CHANGED
@@ -3,10 +3,12 @@
3
  import { useEffect, useRef, useState } from "react"
4
 
5
  export function useCountdown({
 
6
  timerId,
7
  durationInSec,
8
  onEnd = () => {},
9
  }: {
 
10
  timerId: string | number
11
  durationInSec: number
12
  onEnd: () => void
@@ -21,17 +23,25 @@ export function useCountdown({
21
  clearInterval(intervalRef.current)
22
  setElapsedTimeInMs(0)
23
  startedAt.current = new Date()
24
- intervalRef.current = setInterval(() => {
25
- const now = new Date()
26
- const newElapsedInMs = Math.min(durationInMs, now.getTime() - startedAt.current!.getTime())
27
- setElapsedTimeInMs(newElapsedInMs)
28
- if (elapsedTimeInMs >= durationInMs) {
29
- console.log("end of timer")
30
- clearInterval(intervalRef.current)
31
- onEnd()
32
- }
33
- }, 100)
34
- }, [timerId, durationInMs])
 
 
 
 
 
 
 
 
35
 
36
  const remainingTimeInMs = Math.max(0, durationInMs - elapsedTimeInMs)
37
 
 
3
  import { useEffect, useRef, useState } from "react"
4
 
5
  export function useCountdown({
6
+ isActive,
7
  timerId,
8
  durationInSec,
9
  onEnd = () => {},
10
  }: {
11
+ isActive: boolean
12
  timerId: string | number
13
  durationInSec: number
14
  onEnd: () => void
 
23
  clearInterval(intervalRef.current)
24
  setElapsedTimeInMs(0)
25
  startedAt.current = new Date()
26
+
27
+ if (isActive) {
28
+ intervalRef.current = setInterval(() => {
29
+ const now = new Date()
30
+ const newElapsedInMs = Math.min(durationInMs, now.getTime() - startedAt.current!.getTime())
31
+ setElapsedTimeInMs(newElapsedInMs)
32
+ if (elapsedTimeInMs > durationInMs) {
33
+ console.log("end of timer")
34
+ clearInterval(intervalRef.current)
35
+ onEnd()
36
+ }
37
+ }, 100)
38
+ }
39
+
40
+ return () => {
41
+ console.log("destruction of timer")
42
+ clearInterval(intervalRef.current)
43
+ }
44
+ }, [isActive, timerId, durationInMs])
45
 
46
  const remainingTimeInMs = Math.max(0, durationInMs - elapsedTimeInMs)
47
 
src/types.ts CHANGED
@@ -303,3 +303,14 @@ export type VideoOptions = {
303
 
304
  steps?: number
305
  }
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
  steps?: number
305
  }
306
+
307
+ export type SDXLModel = {
308
+ image: string
309
+ title: string
310
+ repo: string
311
+ trigger_word: string
312
+ weights: string
313
+ is_compatible: boolean
314
+ likes: number
315
+ downloads: number
316
+ }