Update imagic.py
Browse files
imagic.py
CHANGED
@@ -5,7 +5,7 @@
|
|
5 |
import inspect
|
6 |
import warnings
|
7 |
from typing import List, Optional, Union
|
8 |
-
|
9 |
import numpy as np
|
10 |
import torch
|
11 |
import torch.nn.functional as F
|
@@ -236,7 +236,8 @@ class ImagicStableDiffusionPipeline(DiffusionPipeline):
|
|
236 |
text_embeddings_orig = text_embeddings.clone()
|
237 |
|
238 |
# Initialize the optimizer
|
239 |
-
|
|
|
240 |
[text_embeddings], # only optimize the embeddings
|
241 |
lr=embedding_learning_rate,
|
242 |
)
|
|
|
5 |
import inspect
|
6 |
import warnings
|
7 |
from typing import List, Optional, Union
|
8 |
+
import bitsandbytes as bnb
|
9 |
import numpy as np
|
10 |
import torch
|
11 |
import torch.nn.functional as F
|
|
|
236 |
text_embeddings_orig = text_embeddings.clone()
|
237 |
|
238 |
# Initialize the optimizer
|
239 |
+
|
240 |
+
optimizer = bnb.optim.Adam8bit(
|
241 |
[text_embeddings], # only optimize the embeddings
|
242 |
lr=embedding_learning_rate,
|
243 |
)
|