"use client" import { useEffect, useRef, useState, useTransition } from "react" import { useSpring, animated } from "@react-spring/web" import { usePathname, useRouter, useSearchParams } from "next/navigation" import { useToast } from "@/components/ui/use-toast" import { cn } from "@/lib/utils" import { headingFont } from "@/app/interface/fonts" import { useCharacterLimit } from "@/lib/useCharacterLimit" import { generateAnimation } from "@/app/server/actions/animation" import { interpolateVideo } from "@/app/server/actions/interpolation" import { getLatestPosts, getPost, postToCommunity } from "@/app/server/actions/community" import { getSDXLModels } from "@/app/server/actions/models" import { HotshotImageInferenceSize, Post, QualityLevel, QualityOption, SDXLModel } from "@/types" import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip" import { TooltipProvider } from "@radix-ui/react-tooltip" import { isRateLimitError } from "@/app/server/utils/isRateLimitError" import { useCountdown } from "@/lib/useCountdown" import { Countdown } from "../countdown" const qualityOptions = [ { level: "low", label: "Low (~ 30 sec)" }, { level: "medium", label: "Medium (~90 secs)" } ] as QualityOption[] type Stage = "generate" | "interpolate" | "finished" export function Generate() { const router = useRouter() const pathname = usePathname() const searchParams = useSearchParams() const searchParamsEntries = searchParams ? Array.from(searchParams.entries()) : [] const [_isPending, startTransition] = useTransition() const scrollRef = useRef(null) const videoRef = useRef(null) const [isLocked, setLocked] = useState(false) const [promptDraft, setPromptDraft] = useState("") const [assetUrl, setAssetUrl] = useState("") const [isOverSubmitButton, setOverSubmitButton] = useState(false) const [models, setModels] = useState([]) const [selectedModel, setSelectedModel] = useState() const [runs, setRuns] = useState(0) const runsRef = useRef(0) const [showModels, setShowModels] = useState(true) const [communityRoll, setCommunityRoll] = useState([]) const [stage, setStage] = useState("generate") const [qualityLevel, setQualityLevel] = useState("low") const { toast } = useToast() const { progressPercent, remainingTimeInSec } = useCountdown({ isActive: isLocked, timerId: runs, // everytime we change this, the timer will reset durationInSec: /*stage === "interpolate" ? 30 :*/ 90, // it usually takes 40 seconds, but there might be lag onEnd: () => {} }) const { shouldWarn, colorClass, nbCharsUsed, nbCharsLimits } = useCharacterLimit({ value: promptDraft, nbCharsLimits: 70, warnBelow: 10, }) const submitButtonBouncer = useSpring({ transform: isOverSubmitButton ? 'scale(1.05)' : 'scale(1.0)', boxShadow: isOverSubmitButton ? `0px 5px 15px 0px rgba(0, 0, 0, 0.05)` : `0px 0px 0px 0px rgba(0, 0, 0, 0.05)`, loop: true, config: { tension: 300, friction: 10, }, }) const handleSubmit = () => { if (isLocked) { return } if (!promptDraft) { return } setShowModels(false) setRuns(runsRef.current + 1) setLocked(true) setStage("generate") scrollRef.current?.scroll({ top: 0, behavior: 'smooth' }) startTransition(async () => { const huggingFaceLora = selectedModel ? selectedModel.repo.trim() : "KappaNeuro/studio-ghibli-style" const triggerWord = selectedModel ? selectedModel.trigger_word : "Studio Ghibli Style" // now you got a read/write object const current = new URLSearchParams(searchParamsEntries) current.set("prompt", promptDraft) current.set("model", huggingFaceLora) const search = current.toString() router.push(`${pathname}${search ? `?${search}` : ""}`) const size: HotshotImageInferenceSize = "608x416" // 608x416 @ 25 steps -> 32 seconds const steps = qualityLevel === "low" ? 30 : 45 let key = "" try { const res = await fetch("/api/get-key", { method: "GET", headers: { Accept: "application/json", "Content-Type": "application/json", }, cache: 'no-store', }) key = await res.text() } catch (err) { console.error("failed to get key, but this is not a blocker") } const params = { positivePrompt: promptDraft, negativePrompt: "", huggingFaceLora, triggerWord, nbFrames: 10, // if duration is 1000ms then it means 8 FPS duration: 1000, // in ms steps, size, key } let rawAssetUrl = "" try { // console.log("starting transition, calling generateAnimation") rawAssetUrl = await generateAnimation(params) if (!rawAssetUrl) { throw new Error("invalid asset url") } setAssetUrl(rawAssetUrl) } catch (err) { // check the rate limit if (isRateLimitError(err)) { console.error("error, too many requests") toast({ title: "You can generate only one video per minute 👀", description: "Please wait a bit before trying again 🤗", }) setLocked(false) return } else { toast({ title: "We couldn't generate your video 👀", description: "We are probably over capacity, but you can try again 🤗", }) } console.log("generation failed! probably just a Gradio failure, so let's just run the round robin again!") try { rawAssetUrl = await generateAnimation(params) } catch (err) { // check the rate limit if (isRateLimitError(err)) { console.error("error, too many requests") toast({ title: "Error: the free server is over capacity 👀", description: "You can generate 2 videos per minute 🤗 Please try again in a moment!", }) setLocked(false) return } console.error(`generation failed again! ${err}`) } } if (!rawAssetUrl) { console.log("failed to generate the video, aborting") setLocked(false) return } setAssetUrl(rawAssetUrl) let assetUrl = rawAssetUrl setStage("interpolate") setRuns(runsRef.current + 1) try { assetUrl = await interpolateVideo(rawAssetUrl) if (!assetUrl) { throw new Error("invalid interpolated asset url") } setAssetUrl(assetUrl) } catch (err) { console.log(`failed to interpolate the video, but this is not a blocker: ${err}`) } setLocked(false) setStage("generate") if (process.env.NEXT_PUBLIC_ENABLE_COMMUNITY_SHARING !== "true") { return } try { const post = await postToCommunity({ prompt: promptDraft, model: huggingFaceLora, assetUrl, }) console.log("successfully submitted to the community!", post) // now you got a read/write object const current = new URLSearchParams(searchParamsEntries) current.set("postId", post.postId.trim()) current.set("prompt", post.prompt.trim()) current.set("model", post.model.trim()) const search = current.toString() router.push(`${pathname}${search ? `?${search}` : ""}`) } catch (err) { console.error(`not a blocker, but we failed to post to the community (reason: ${err})`) } }) } useEffect(() => { startTransition(async () => { const models = await getSDXLModels() setModels(models) const defaultModel = models.find(model => model.repo.toLowerCase().includes("ghibli")) || models[0] if (defaultModel) { setSelectedModel(defaultModel) } // now we load URL params const current = new URLSearchParams(searchParamsEntries) // URL query params const existingPostId = current.get("postId") || "" const existingPrompt = current.get("prompt")?.trim() || "" const existingModelName = current.get("model")?.toLowerCase().trim() || "" // if and only if we don't have a post id, then we look at the other query params if (existingPrompt) { setPromptDraft(existingPrompt) } if (existingModelName) { let existingModel = models.find(model => { return ( model.repo.toLowerCase().trim().includes(existingModelName) || model.title.toLowerCase().trim().includes(existingModelName) ) }) if (existingModel) { setSelectedModel(existingModel) } } // if we have a post id, then we use that to override all the previous values if (existingPostId) { try { const post = await getPost(existingPostId) if (post.assetUrl) { setAssetUrl(post.assetUrl) } if (post.prompt) { setPromptDraft(post.prompt) } if (post.model) { const nameToFind = post.model.toLowerCase().trim() const existingModel = models.find(model => { return ( model.repo.toLowerCase().trim().includes(nameToFind) || model.title.toLowerCase().trim().includes(nameToFind) ) }) if (existingModel) { setSelectedModel(existingModel) } } } catch (err) { console.error(`failed to load the community post (${err})`) } } }) }, []) useEffect(() => { startTransition(async () => { const posts = await getLatestPosts({ maxNbPosts: 32, shuffle: true, }) if (posts?.length) { setCommunityRoll(posts) } }) }, []) const handleSelectCommunityPost = (post: Post) => { if (isLocked) { return } scrollRef.current?.scroll({ top: 0, behavior: 'smooth' }) // now you got a read/write object const current = new URLSearchParams(searchParamsEntries) current.set("postId", post.postId.trim()) current.set("prompt", post.prompt.trim()) current.set("model", post.model.trim()) const search = current.toString() router.push(`${pathname}${search ? `?${search}` : ""}`) if (post.assetUrl) { setAssetUrl(post.assetUrl) } if (post.prompt) { setPromptDraft(post.prompt) } if (post.model) { const nameToFind = post.model.toLowerCase().trim() const existingModel = models.find(model => { return ( model.repo.toLowerCase().trim().includes(nameToFind) || model.title.toLowerCase().trim().includes(nameToFind) ) }) if (existingModel) { setSelectedModel(existingModel) } } } const handleClickPlay = () => { videoRef.current?.play() } return (
{isLocked ? : null}
{assetUrl ?
{assetUrl.startsWith("data:video/mp4") || assetUrl.endsWith(".mp4") ?
: null}
setPromptDraft(e.target.value)} onKeyDown={({ key }) => { if (key === 'Enter') { if (!isLocked) { handleSubmit() } } }} disabled={isLocked} />
{nbCharsUsed} / {nbCharsLimits}
setOverSubmitButton(true)} onMouseLeave={() => setOverSubmitButton(false)} className={cn( `px-6 py-3`, `rounded-full`, `transition-all duration-300 ease-in-out`, isLocked ? `bg-orange-500/20 border-orange-800/10` : `bg-sky-500/80 hover:bg-sky-400/100 border-sky-800/20`, `text-center`, `w-full`, `text-2xl text-sky-50`, `border`, headingFont.className, // `transition-all duration-300`, // `hover:animate-bounce` )} disabled={isLocked} onClick={handleSubmit} > {isLocked ? (stage === "generate" ? `Generating..` : `Smoothing..`) : "Generate" }
{/*

Generation will take about 32 seconds

*/}

{models.length ? `You selected:` : ""}

{models.length ? `${(selectedModel?.title || "").replaceAll("-", " ")}` : "Loading styles.."}

{models.map(model =>
{ if (!isLocked) { setSelectedModel(model) } }}>
{!isLocked &&

{model.title}

}
)}

{communityRoll.length ? "Random community clips:" : "Loading community roll.."}

{communityRoll.map(post =>
{ handleSelectCommunityPost(post) }}>
{!isLocked &&

{post.prompt}

}
)}
) }