---
license: gpl-3.0
---

# Easy CLIP Embedding Alignment with Guanaco Models
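
The snippet below swaps Stable Diffusion's CLIP text encoder for the last hidden states of a Guanaco (LLaMA) model: a small two-layer linear adapter (`LLMToCLIP`, weights in `toclip.pth`) projects the 4096-dimensional LLM states into CLIP's 768-dimensional text-embedding space, and the result is passed to the pipeline via `prompt_embeds`.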

```python
import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline
from transformers import LlamaForCausalLM, LlamaTokenizer

# Load your Stable Diffusion checkpoint here (arguments elided)
pipe = StableDiffusionPipeline(...)

# Alternatively, load a 4-bit quantized checkpoint for consumer hardware:
# llm_model = load_quant('/GuanacoOnConsumerHardware', 'guanaco7b-4bit-128g.pt', 4, 128, 0)
llm_tokenizer = LlamaTokenizer.from_pretrained("JosephusCheung/Guanaco", use_fast=False)
llm_model = LlamaForCausalLM.from_pretrained(
    "JosephusCheung/Guanaco", device_map="auto", torch_dtype=torch.float16
)


class LLMToCLIP(nn.Module):
    """Two-layer linear adapter from LLaMA hidden states to CLIP text embeddings."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4096, 4096, bias=False)   # mix features within the LLM space
        self.deproj = nn.Linear(4096, 768, bias=False)  # 4096-d LLM states -> 768-d CLIP states

    def forward(self, x):
        return self.deproj(self.proj(x))


llm_to_clip = LLMToCLIP()
llm_to_clip.load_state_dict(torch.load("toclip.pth"))

# Tokenize a prompt (example text) and take the LLM's last hidden layer
input_ids = llm_tokenizer("a photo of a cat", return_tensors="pt").input_ids.to(llm_model.device)
llm_embeddings = llm_model(input_ids=input_ids, output_hidden_states=True).hidden_states[-1]

# Feed the aligned embeddings to Stable Diffusion in place of CLIP text embeddings;
# move them to the adapter's device/dtype (float32 on CPU here) before projecting
image = pipe(prompt_embeds=llm_to_clip(llm_embeddings.float().cpu())).images[0]
```
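
For completeness, here is a minimal sketch of how an adapter like `LLMToCLIP` could be trained, assuming you regress the projected LLM states onto the pipeline's own CLIP text-encoder outputs for paired prompts. The prompt list, padding scheme, and hyperparameters are illustrative assumptions, not the recipe used to produce `toclip.pth`:

```python
import torch
import torch.nn as nn

# Hypothetical training loop; `pipe`, `llm_model`, `llm_tokenizer`, and `LLMToCLIP`
# are defined as above, and everything except the LLM stays on CPU for simplicity.
prompts = ["a photo of a cat", "an oil painting of a forest at dawn"]  # assumed toy data
adapter = LLMToCLIP()
optimizer = torch.optim.Adam(adapter.parameters(), lr=1e-4)  # assumed hyperparameters

# LLaMA tokenizers ship without a pad token; reuse EOS so we can pad to CLIP's length
llm_tokenizer.pad_token = llm_tokenizer.eos_token
max_len = pipe.tokenizer.model_max_length  # 77 for Stable Diffusion's CLIP

for epoch in range(10):
    for prompt in prompts:
        # Target: the CLIP text embeddings Stable Diffusion would normally receive
        clip_ids = pipe.tokenizer(
            prompt, padding="max_length", max_length=max_len,
            truncation=True, return_tensors="pt",
        ).input_ids
        with torch.no_grad():
            target = pipe.text_encoder(clip_ids)[0].float()

        # Source: the LLM's last hidden layer for the same prompt, padded to the same length
        llm_ids = llm_tokenizer(
            prompt, padding="max_length", max_length=max_len,
            truncation=True, return_tensors="pt",
        ).input_ids.to(llm_model.device)
        with torch.no_grad():
            hidden = llm_model(input_ids=llm_ids, output_hidden_states=True).hidden_states[-1]

        # Token-by-token regression of the projected LLM states onto the CLIP states
        loss = nn.functional.mse_loss(adapter(hidden.float().cpu()), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

torch.save(adapter.state_dict(), "toclip.pth")
```

Padding both tokenizations to CLIP's 77-token length keeps the two sequences elementwise comparable, which is what the plain MSE loss above requires.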