
LIVE-BART

The LIVE-BART model was proposed in the paper "Learning to Imagine: Visually-Augmented Natural Language Generation" by Tianyi Tang, Yushuo Chen, Yifan Du, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.

Detailed information and instructions can be found at https://github.com/RUCAIBox/LIVE.

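You should install the version of transformers provided at https://github.com/RUCAIBox/LIVE; the `image_fusion_encoder` option and the image-related arguments used below are not available in the standard release.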

```python
import torch
import torch.nn as nn
# BartForConditionalGeneration here comes from the transformers version installed from the LIVE repository
from transformers import BartForConditionalGeneration, AutoModel


class LiveModel(nn.Module):
    def __init__(self):
        super().__init__()

        # LIVE-BART text model with the image fusion encoder enabled
        self.model = BartForConditionalGeneration.from_pretrained('RUCAIBox/live-bart-base', image_fusion_encoder=True)
        # CLIP vision tower used to encode the images
        self.vision_model = AutoModel.from_pretrained('openai/clip-vit-base-patch32').vision_model
        hidden_size = self.model.config.hidden_size
        # Two-layer MLP projecting CLIP features to the BART hidden size
        self.trans = nn.Sequential(
            nn.Linear(self.vision_model.config.hidden_size, hidden_size * 4),
            nn.ReLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )


model = LiveModel()
# Load the trained weights of the `trans` projection (file name as given by the authors)
trans = torch.load('trans.bart.pth', map_location='cpu')
model.trans.load_state_dict(trans)

# kwargs to model.model.forward() and model.model.generate():
# input_ids       [batch_size, seq_len], same as in Hugging Face Transformers
# attention_mask  [batch_size, seq_len], same as in Hugging Face Transformers
# labels          [batch_size, seq_len], same as in Hugging Face Transformers
# image_embeds    [batch_size, image_num*patch_num, image_hidden_size], CLIP image features projected with `trans`;
#                 image_num can be the number of sentences in the text; for openai/clip-vit-base-patch32,
#                 patch_num is 50 and image_hidden_size is 768
# images_mask     [batch_size, seq_len, image_num], the mask in Figure 1 of the paper; entry (i, j) = 1 means
#                 the i-th token should attend to the j-th image
# images_mask_2d  [batch_size, seq_len], 1 means the i-th token should not be visually augmented,
#                 i.e., it should not attend to any image
```
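The masks described in the comments above can be derived from a token-to-sentence alignment. The following is a plain-Python sketch, not the authors' preprocessing code. It assumes one image per sentence (as the comment on `image_num` suggests), that a token in sentence k attends only to image k, and that tokens marked `None` (here, special tokens) are not augmented. The helpers `build_image_masks` and `clip_patch_num` are hypothetical, and the exact dtype the LIVE code expects for these masks is not stated here, so convert the nested lists (e.g. with `torch.tensor`) to whatever it requires.

```python
from typing import List, Optional, Tuple

# Input resolution and patch size of openai/clip-vit-base-patch32
CLIP_IMAGE_SIZE = 224
CLIP_PATCH_SIZE = 32


def clip_patch_num(image_size: int = CLIP_IMAGE_SIZE, patch_size: int = CLIP_PATCH_SIZE) -> int:
    """Vision tokens per image: one per patch plus the class token."""
    return (image_size // patch_size) ** 2 + 1


def build_image_masks(
    sentence_ids: List[List[Optional[int]]], image_num: int
) -> Tuple[List[List[List[int]]], List[List[int]]]:
    """Hypothetical helper returning (images_mask, images_mask_2d) as nested lists.

    sentence_ids[b][i] is the sentence index of token i in example b. Assumption:
    sentence k is paired with image k. None marks a token that should not be
    visually augmented (e.g. special or padding tokens).
    """
    images_mask: List[List[List[int]]] = []   # [batch_size, seq_len, image_num]
    images_mask_2d: List[List[int]] = []      # [batch_size, seq_len]
    for ids in sentence_ids:
        rows, flags = [], []
        for sid in ids:
            row = [0] * image_num
            if sid is None:
                flags.append(1)  # 1 = do not augment; the row stays all zeros
            else:
                if not 0 <= sid < image_num:
                    raise ValueError(f"sentence index {sid} has no image (image_num={image_num})")
                row[sid] = 1     # attend only to the image of this token's sentence
                flags.append(0)
            rows.append(row)
        images_mask.append(rows)
        images_mask_2d.append(flags)
    return images_mask, images_mask_2d


# Illustrative split of "<s> A dog runs . A cat sleeps . </s>" into 10 tokens, 2 sentences
ids = [[None, 0, 0, 0, 0, 1, 1, 1, 1, None]]
images_mask, images_mask_2d = build_image_masks(ids, image_num=2)

print(clip_patch_num())                          # 50
print(2 * clip_patch_num())                      # 100, i.e. image_num * patch_num
print(images_mask[0][1], images_mask[0][5])      # [1, 0] [0, 1]
print(images_mask_2d[0])                         # [1, 0, 0, 0, 0, 0, 0, 0, 0, 1]
```

With two images per example, the second dimension of `image_embeds` would be 2 × 50 = 100 for openai/clip-vit-base-patch32. Excluding only special tokens from augmentation is one possible choice made for this sketch; the repository may align tokens and images differently.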