mjwong commited on
Commit
b8f30d7
1 Parent(s): 17c50bb

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +28 -0
README.md CHANGED
@@ -21,6 +21,8 @@ Gautier Izacard, Mathilde Caron, Lucas Hosseini, Sebastian Riedel, Piotr Bojanow
21
 
22
  ## How to use the model
23
 
 
 
24
  The model can be loaded with the `zero-shot-classification` pipeline like so:
25
 
26
  ```python
@@ -44,6 +46,32 @@ candidate_labels = ['travel', 'cooking', 'dancing', 'exploration']
44
  classifier(sequence_to_classify, candidate_labels, multi_class=True)
45
  ```
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  ### Eval results
48
  The model was evaluated using the dev sets for MultiNLI and test sets for ANLI. The metric used is accuracy.
49
 
 
21
 
22
  ## How to use the model
23
 
24
+ ### With the zero-shot classification pipeline
25
+
26
  The model can be loaded with the `zero-shot-classification` pipeline like so:
27
 
28
  ```python
 
46
  classifier(sequence_to_classify, candidate_labels, multi_class=True)
47
  ```
48
 
49
+ ### With manual PyTorch
50
+
51
+ The model can also be applied on NLI tasks like so:
52
+
53
+ ```python
54
+ import torch
55
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
56
+
57
+ # device = "cuda:0" or "cpu"
58
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
59
+
60
+ model_name = 'mjwong/contriever-msmarco-mnli'
61
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
62
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
63
+
64
+ premise = "But I thought you'd sworn off coffee."
65
+ hypothesis = "I thought that you vowed to drink more coffee."
66
+
67
+ input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
68
+ output = model(input["input_ids"].to(device))
69
+ prediction = torch.softmax(output["logits"][0], -1).tolist()
70
+ label_names = ["entailment", "neutral", "contradiction"]
71
+ prediction = {name: round(float(pred) * 100, 2) for pred, name in zip(prediction, label_names)}
72
+ print(prediction)
73
+ ```
74
+
75
  ### Eval results
76
  The model was evaluated using the dev sets for MultiNLI and test sets for ANLI. The metric used is accuracy.
77