Spaces:
Runtime error
Runtime error
crypto-code
commited on
Commit
β’
212945c
1
Parent(s):
c4f1082
Update llama/m2ugen.py
Browse files- llama/m2ugen.py +70 -25
llama/m2ugen.py
CHANGED
@@ -231,9 +231,9 @@ class M2UGen(nn.Module):
|
|
231 |
self.music_decoder = self.args.music_decoder.lower()
|
232 |
|
233 |
# 4. prefix
|
234 |
-
self.query_layer =
|
235 |
self.query_len = 1
|
236 |
-
self.prefix_query = nn.Embedding(self.query_layer * self.query_len, self.model_args.dim).to("cuda:0")
|
237 |
|
238 |
# 5. knn
|
239 |
self.knn = knn
|
@@ -492,30 +492,52 @@ class M2UGen(nn.Module):
|
|
492 |
h = self.llama.tok_embeddings(tokens).to("cuda:0")
|
493 |
freqs_cis = self.llama.freqs_cis.to("cuda:0")
|
494 |
freqs_cis = freqs_cis[start_pos:start_pos + seqlen]
|
495 |
-
|
496 |
-
feats = torch.zeros((1, 1, 4096)).to("cuda:0")
|
497 |
-
if audio_feats is not None:
|
498 |
-
feats += audio_feats
|
499 |
-
if video_feats is not None:
|
500 |
-
feats += video_feats
|
501 |
-
if image_feats is not None:
|
502 |
-
feats += image_feats
|
503 |
|
504 |
mask = None
|
505 |
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device="cuda:0")
|
506 |
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
|
507 |
|
508 |
music_output_embedding = []
|
509 |
-
for layer in self.llama.layers[:-
|
510 |
h = layer(h, 0, freqs_cis, mask)
|
511 |
music_output_embedding.append(h)
|
512 |
|
513 |
-
prefix_query = self.prefix_query.weight.reshape(
|
|
|
514 |
|
515 |
prefix_index = 0
|
516 |
-
|
517 |
-
|
518 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
519 |
|
520 |
h = self.llama.norm(h)
|
521 |
output = self.llama.output(h[:, -1, :])
|
@@ -523,30 +545,53 @@ class M2UGen(nn.Module):
|
|
523 |
return output.float(), torch.cat(music_output_embedding[-1:], dim=1)
|
524 |
|
525 |
def forward(self, tokens, labels, audios=None, imgs=None, videos=None, music_caption=None):
|
526 |
-
|
527 |
if audios is not None:
|
528 |
-
|
529 |
if videos is not None:
|
530 |
-
|
531 |
if imgs is not None:
|
532 |
-
|
533 |
_bsz, seqlen = tokens.shape
|
534 |
|
535 |
h = self.llama.tok_embeddings(tokens.to(self.device))
|
536 |
freqs_cis = self.llama.freqs_cis.to(h.device)
|
537 |
freqs_cis = freqs_cis[:seqlen]
|
538 |
-
mask = None
|
539 |
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device)
|
540 |
mask = torch.triu(mask, diagonal=0 + 1).type_as(h)
|
541 |
|
542 |
-
for layer in self.llama.layers[:-
|
543 |
h = layer(h, 0, freqs_cis, mask)
|
544 |
-
prefix_query = self.prefix_query.weight.reshape(
|
|
|
|
|
545 |
prefix_index = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
546 |
|
547 |
-
|
548 |
-
|
549 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
550 |
|
551 |
final_hidden = h
|
552 |
h = self.llama.norm(h)
|
|
|
231 |
self.music_decoder = self.args.music_decoder.lower()
|
232 |
|
233 |
# 4. prefix
|
234 |
+
self.query_layer = 6
|
235 |
self.query_len = 1
|
236 |
+
self.prefix_query = nn.Embedding(self.query_layer * 3 * self.query_len, self.model_args.dim).to("cuda:0")
|
237 |
|
238 |
# 5. knn
|
239 |
self.knn = knn
|
|
|
492 |
h = self.llama.tok_embeddings(tokens).to("cuda:0")
|
493 |
freqs_cis = self.llama.freqs_cis.to("cuda:0")
|
494 |
freqs_cis = freqs_cis[start_pos:start_pos + seqlen]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
495 |
|
496 |
mask = None
|
497 |
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device="cuda:0")
|
498 |
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
|
499 |
|
500 |
music_output_embedding = []
|
501 |
+
for layer in self.llama.layers[:-3 * self.query_layer]:
|
502 |
h = layer(h, 0, freqs_cis, mask)
|
503 |
music_output_embedding.append(h)
|
504 |
|
505 |
+
prefix_query = self.prefix_query.weight.reshape(
|
506 |
+
self.query_layer * 3, 1, 4096).unsqueeze(1)
|
507 |
|
508 |
prefix_index = 0
|
509 |
+
if audio_feats is not None:
|
510 |
+
for layer in self.llama.layers[-3 * self.query_layer:-2 * self.query_layer]:
|
511 |
+
h = layer(h, 0, freqs_cis, mask, audio_feats + prefix_query[prefix_index])
|
512 |
+
music_output_embedding.append(h)
|
513 |
+
prefix_index = prefix_index + 1
|
514 |
+
else:
|
515 |
+
for layer in self.llama.layers[-3 * self.query_layer:-2 * self.query_layer]:
|
516 |
+
h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
|
517 |
+
music_output_embedding.append(h)
|
518 |
+
prefix_index = prefix_index + 1
|
519 |
+
|
520 |
+
if image_feats is not None:
|
521 |
+
for layer in self.llama.layers[-2 * self.query_layer:-1 * self.query_layer]:
|
522 |
+
h = layer(h, 0, freqs_cis, mask, image_feats + prefix_query[prefix_index])
|
523 |
+
music_output_embedding.append(h)
|
524 |
+
prefix_index = prefix_index + 1
|
525 |
+
else:
|
526 |
+
for layer in self.llama.layers[-2 * self.query_layer:-1 * self.query_layer]:
|
527 |
+
h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
|
528 |
+
music_output_embedding.append(h)
|
529 |
+
prefix_index = prefix_index + 1
|
530 |
+
|
531 |
+
if video_feats is not None:
|
532 |
+
for layer in self.llama.layers[-1 * self.query_layer:]:
|
533 |
+
h = layer(h, 0, freqs_cis, mask, video_feats + prefix_query[prefix_index])
|
534 |
+
music_output_embedding.append(h)
|
535 |
+
prefix_index = prefix_index + 1
|
536 |
+
else:
|
537 |
+
for layer in self.llama.layers[-1 * self.query_layer:]:
|
538 |
+
h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
|
539 |
+
music_output_embedding.append(h)
|
540 |
+
prefix_index = prefix_index + 1
|
541 |
|
542 |
h = self.llama.norm(h)
|
543 |
output = self.llama.output(h[:, -1, :])
|
|
|
545 |
return output.float(), torch.cat(music_output_embedding[-1:], dim=1)
|
546 |
|
547 |
def forward(self, tokens, labels, audios=None, imgs=None, videos=None, music_caption=None):
|
548 |
+
audio_feats, video_feats, image_feats = None, None, None
|
549 |
if audios is not None:
|
550 |
+
audio_feats = self.forward_audio({'Audio': [audios, 1]})
|
551 |
if videos is not None:
|
552 |
+
video_feats = self.forward_video({'Video': [videos, 1]})
|
553 |
if imgs is not None:
|
554 |
+
image_feats = self.forward_image({'Image': [imgs, 1]})
|
555 |
_bsz, seqlen = tokens.shape
|
556 |
|
557 |
h = self.llama.tok_embeddings(tokens.to(self.device))
|
558 |
freqs_cis = self.llama.freqs_cis.to(h.device)
|
559 |
freqs_cis = freqs_cis[:seqlen]
|
|
|
560 |
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device)
|
561 |
mask = torch.triu(mask, diagonal=0 + 1).type_as(h)
|
562 |
|
563 |
+
for layer in self.llama.layers[:-3 * self.query_layer]:
|
564 |
h = layer(h, 0, freqs_cis, mask)
|
565 |
+
prefix_query = self.prefix_query.weight.reshape(
|
566 |
+
self.query_layer * 3, 1, 4096).unsqueeze(1)
|
567 |
+
|
568 |
prefix_index = 0
|
569 |
+
if audio_feats is not None:
|
570 |
+
for layer in self.llama.layers[-3 * self.query_layer:-2 * self.query_layer]:
|
571 |
+
h = layer(h, 0, freqs_cis, mask, audio_feats + prefix_query[prefix_index])
|
572 |
+
prefix_index = prefix_index + 1
|
573 |
+
else:
|
574 |
+
for layer in self.llama.layers[-3 * self.query_layer:-2 * self.query_layer]:
|
575 |
+
h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
|
576 |
+
prefix_index = prefix_index + 1
|
577 |
|
578 |
+
if image_feats is not None:
|
579 |
+
for layer in self.llama.layers[-2 * self.query_layer:-1 * self.query_layer]:
|
580 |
+
h = layer(h, 0, freqs_cis, mask, image_feats + prefix_query[prefix_index])
|
581 |
+
prefix_index = prefix_index + 1
|
582 |
+
else:
|
583 |
+
for layer in self.llama.layers[-2 * self.query_layer:-1 * self.query_layer]:
|
584 |
+
h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
|
585 |
+
prefix_index = prefix_index + 1
|
586 |
+
|
587 |
+
if video_feats is not None:
|
588 |
+
for layer in self.llama.layers[-1 * self.query_layer:]:
|
589 |
+
h = layer(h, 0, freqs_cis, mask, video_feats + prefix_query[prefix_index])
|
590 |
+
prefix_index = prefix_index + 1
|
591 |
+
else:
|
592 |
+
for layer in self.llama.layers[-1 * self.query_layer:]:
|
593 |
+
h = layer(h, 0, freqs_cis, mask, prefix_query[prefix_index])
|
594 |
+
prefix_index = prefix_index + 1
|
595 |
|
596 |
final_hidden = h
|
597 |
h = self.llama.norm(h)
|