lnyan committed
Commit
19712a8
1 Parent(s): eced41d
Files changed (2)
  1. app.py +134 -0
  2. requirements.txt +14 -0
app.py ADDED
import io
import base64

import gradio as gr
import numpy as np
import spaces
import jax
import jax.numpy as jnp
from jax import Array as Tensor
from flax import nnx
from transformers import (FlaxCLIPTextModel, CLIPTokenizer, FlaxT5EncoderModel,
                          T5Tokenizer)


class HFEmbedder(nnx.Module):
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module, params = FlaxCLIPTextModel.from_pretrained(version, _do_init=False, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module, params = FlaxT5EncoderModel.from_pretrained(version, _do_init=False, **hf_kwargs)
        # _do_init=False skips random parameter initialization; mark the module
        # as initialized and place the loaded parameters on the GPU ourselves.
        self.hf_module._is_initialized = True
        self.hf_module.params = jax.tree_util.tree_map(
            lambda x: jax.device_put(x, jax.devices("cuda")[0]), params)

    def tokenize(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="jax",
        )
        return batch_encoding["input_ids"]

    def __call__(self, input_ids: Tensor) -> Tensor:
        outputs = self.hf_module(
            input_ids=input_ids,
            attention_mask=None,
            output_hidden_states=False,
            train=False,
        )
        return outputs[self.output_key]


def load_t5(device: str = "cuda", max_length: int = 512) -> HFEmbedder:
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    return HFEmbedder("lnyan/t5-v1_1-xxl-encoder", max_length=max_length, dtype=jnp.bfloat16)


def load_clip(device: str = "cuda") -> HFEmbedder:
    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, dtype=jnp.bfloat16)


@spaces.GPU(duration=30)
def load_encoders():
    # FLUX.1-schnell uses a 256-token T5 context; the dev variant uses 512
    is_schnell = True
    t5 = load_t5("cuda", max_length=256 if is_schnell else 512)
    clip = load_clip("cuda")
    return t5, clip


def b64(txt, vec):
    # Serialize both embeddings into an in-memory .npz archive and
    # base64-encode it so it can be returned through a Gradio Textbox.
    buffer = io.BytesIO()
    jnp.savez(buffer, txt=txt, vec=vec)
    buffer.seek(0)
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return encoded

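
def from_b64(encoded):
    # Inverse of b64, included for reference only (nothing in this app calls
    # it; decoding normally happens on the client side): unpack the base64
    # string back into the saved .npz arrays.
    data = np.load(io.BytesIO(base64.b64decode(encoded)))
    return data["txt"], data["vec"]
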
t5, clip = load_encoders()


@spaces.GPU(duration=10)
def convert(prompt):
    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5.tokenize(prompt)
    txt = t5(txt)      # per-token T5 embeddings (last_hidden_state)
    vec = clip.tokenize(prompt)
    vec = clip(vec)    # pooled CLIP text vector (pooler_output)
    return b64(txt, vec)


with gr.Blocks() as demo:
    gr.Markdown("""A workaround for flux-flax to fit into 40G VRAM""")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="prompt")
            convert_btn = gr.Button(value="Convert")
        with gr.Column():
            output = gr.Textbox(label="output")

    convert_btn.click(convert, inputs=prompt, outputs=output, api_name="convert")


demo.launch()
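
A client can call the convert endpoint and unpack the result like this (a minimal sketch: the Space id below is a placeholder, gradio_client must be installed, and bfloat16 payloads may additionally need ml_dtypes imported on the client):

import io
import base64
import numpy as np
from gradio_client import Client

client = Client("user/space-name")  # placeholder: substitute the actual Space id
encoded = client.predict("a photo of a forest", api_name="/convert")
data = np.load(io.BytesIO(base64.b64decode(encoded)))
txt, vec = data["txt"], data["vec"]  # T5 token embeddings and pooled CLIP vector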
requirements.txt ADDED
jax[cuda12]
flax==0.9.0
flash_attn_jax
torch
torchvision
opencv-python-headless
einops
huggingface_hub
transformers
tokenizers
sentencepiece
fire
invisible-watermark
ml-dtypes