sarahyurick committed on
Commit
cb6e443
1 Parent(s): 6b32a36

Update README.md

Files changed (1)
  1. README.md +37 -0
README.md CHANGED
@@ -83,6 +83,43 @@ NeMo Curator improves generative AI model accuracy by processing text, image, an
 
  The inference code for this model is available through the NeMo Curator GitHub repository. Check out this [example notebook](https://github.com/NVIDIA/NeMo-Curator/tree/main/tutorials/distributed_data_classification) to get started.
 
+ # How to Use in Transformers
+ To use the multilingual domain classifier, use the following code:
+ ```
+ import torch
+ from torch import nn
+ from transformers import AutoModel, AutoTokenizer, AutoConfig
+ from huggingface_hub import PyTorchModelHubMixin
+
+ class CustomModel(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, config):
+         super(CustomModel, self).__init__()
+         self.model = AutoModel.from_pretrained(config["base_model"])
+         self.dropout = nn.Dropout(config["fc_dropout"])
+         self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))
+
+     def forward(self, input_ids, attention_mask):
+         features = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
+         dropped = self.dropout(features)
+         outputs = self.fc(dropped)
+         return torch.softmax(outputs[:, 0, :], dim=1)
+
+ # Setup configuration and model
+ config = AutoConfig.from_pretrained("nvidia/multilingual-domain-classifier")
+ tokenizer = AutoTokenizer.from_pretrained("nvidia/multilingual-domain-classifier")
+ model = CustomModel.from_pretrained("nvidia/multilingual-domain-classifier")
+
+ # Prepare and process inputs
+ text_samples = ["Los deportes son un dominio popular", "La política es un dominio popular"]
+ inputs = tokenizer(text_samples, return_tensors="pt", padding="longest", truncation=True)
+ outputs = model(inputs["input_ids"], inputs["attention_mask"])
+
+ # Predict and display results
+ predicted_classes = torch.argmax(outputs, dim=1)
+ predicted_domains = [config.id2label[class_idx.item()] for class_idx in predicted_classes.cpu().numpy()]
+ print(predicted_domains)
+ ```
+
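The last steps of the snippet above (a softmax over the first-token logits, then an argmax and an `id2label` lookup) can be illustrated without downloading the model. The following is a minimal sketch in plain Python rather than PyTorch. The logits, the three-entry `id2label` mapping, and the helper names `softmax` and `predict_domains` are made up for illustration and are not the classifier's real label set or API.

```python
import math

# Hypothetical label mapping for illustration only; the real mapping
# comes from config.id2label in the model repository.
id2label = {0: "Sports", 1: "News", 2: "Science"}

def softmax(logits):
    """Convert a list of raw scores into probabilities that sum to 1."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # subtract max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

def predict_domains(batch_logits, id2label):
    """Return the label with the highest probability for each row of logits."""
    predictions = []
    for logits in batch_logits:
        probs = softmax(logits)
        best_idx = max(range(len(probs)), key=probs.__getitem__)  # argmax
        predictions.append(id2label[best_idx])
    return predictions

# Made-up logits for two text samples (one row per sample).
example_logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
print(predict_domains(example_logits, id2label))
```

Because softmax preserves the ordering of the scores, taking the argmax of the probabilities selects the same class as taking the argmax of the raw logits; the softmax matters only when the probabilities themselves are needed.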
  # Input & Output
  ## Input
  - Input Type: Text