sarahyurick commited on
Commit
61e8f13
·
verified ·
1 Parent(s): 0f93018

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +128 -3
README.md CHANGED
@@ -2,8 +2,133 @@
2
  tags:
3
  - model_hub_mixin
4
  - pytorch_model_hub_mixin
 
5
  ---
6
 
7
- This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
8
- - Library: [More Information Needed]
9
- - Docs: [More Information Needed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  tags:
3
  - model_hub_mixin
4
  - pytorch_model_hub_mixin
5
+ license: other
6
  ---
7
 
8
+ # Prompt Task/Complexity Classifier
9
+
10
+ # Model Overview
11
+ This is a multi-headed model which classifies English text prompts across task types and complexity dimensions. Tasks are classified across 11 common categories. Complexity is evaluated across 6 dimensions and ensembled to create an overall complexity score. Further information on the taxonomies can be found below.
12
+
13
+ This model is ready for commercial use.
14
+
15
+ **Task types:**
16
+ * Open QA: A question where the response is based on general knowledge
17
+ * Closed QA: A question where the response is based on text/data provided with the prompt
18
+ * Summarization
19
+ * Text Generation
20
+ * Code Generation
21
+ * Chatbot
22
+ * Classification
23
+ * Rewrite
24
+ * Brainstorming
25
+ * Extraction
26
+ * Other
27
+
28
+ **Complexity dimensions:**
29
+ * Overall Complexity Score: The weighted sum of the complexity dimensions. Calculated as 0.35\*CreativityScore + 0.25\*ReasoningScore + 0.15\*ConstraintScore + 0.15\*DomainKnowledgeScore + 0.05\*ContextualKnowledgeScore + 0.05\*NumberOfFewShots
30
+ * Creativity: The level of creativity needed to respond to a prompt. Score range of 0-1, with a higher score indicating more creativity.
31
+ * Reasoning: The extent of logical or cognitive effort required to respond to a prompt. Score range of 0-1, with a higher score indicating more reasoning
32
+ * Contextual Knowledge: The background information necessary to respond to a prompt. Score range of 0-1, with a higher score indicating more contextual knowledge required outside of prompt.
33
+ * Domain Knowledge: The amount of specialized knowledge or expertise within a specific subject area needed to respond to a prompt. Score range of 0-1, with a higher score indicating more domain knowledge is required.
34
+ * Constraints: The number of constraints or conditions provided with the prompt. Score range of 0-1, with a higher score indicating more constraints in the prompt.
35
+ * Number of Few Shots: The number of examples provided with the prompt. Score range of 0-n, with a higher score indicating more examples provided in the prompt.
36
+
37
+ # License
38
+ This model is released under the [NVIDIA Open Model License Agreement](https://developer.download.nvidia.com/licenses/nvidia-open-model-license-agreement-june-2024.pdf).
39
+
40
+ # References
41
+ * [DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing](https://arxiv.org/abs/2111.09543)
42
+ * [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://github.com/microsoft/DeBERTa)
43
+ * [Training language models to follow instructions with human feedback](https://arxiv.org/pdf/2203.02155)
44
+
45
+ # Model Architecture
46
+ The model architecture uses a DeBERTa backbone and incorporates multiple classification heads, each dedicated to a task categorization or complexity dimension. This approach enables the training of a unified network, allowing it to predict simultaneously during inference. Deberta-v3-base can theoretically handle up to 12k tokens, but default context length is set at 512 tokens.
47
+
48
+ # How to Use in NVIDIA NeMo Curator
49
+ [NeMo Curator](https://developer.nvidia.com/nemo-curator) improves generative AI model accuracy by processing text, image, and video data at scale for training and customization. It also provides pre-built pipelines for generating synthetic data to customize and evaluate generative AI systems.
50
+
51
+ The inference code for this model is available through the NeMo Curator GitHub repository. Check out this [example notebook](https://github.com/NVIDIA/NeMo-Curator/tree/main/tutorials/distributed_data_classification) to get started.
52
+
53
+ # How to Use in Transformers
54
+ To use the prompt task and complexity classifier, use the following code:
55
+
56
+ ```python
57
+ # TODO
58
+ ```
59
+
60
+ # Input & Output
61
+ ## Input
62
+ * Input Type: Text
63
+ * Input Format: String
64
+ * Input Parameters: 1D
65
+ * Other Properties Related to Input: Token Limit of 512 tokens
66
+
67
+ ## Output
68
+ * Output Type: Text/Numeric Classifications
69
+ * Output Format: String & Numeric
70
+ * Output Parameters: 1D
71
+ * Other Properties Related to Output: None
72
+
73
+ ## Examples
74
+
75
+ ```
76
+ Prompt: Write a mystery set in a small town where an everyday object goes missing, causing a ripple of curiosity and suspicion. Follow the investigation and reveal the surprising truth behind the disappearance.
77
+ ```
78
+
79
+ | Task | Complexity | Creativity | Reasoning | Contextual Knowledge | Domain Knowledge | Constraints | # of Few Shots |
80
+ |------------------|------------|------------|-----------|-----------------------|------------------|-------------|----------------|
81
+ | Text Generation | 0.472 | 0.867 | 0.056 | 0.048 | 0.226 | 0.785 | 0 |
82
+
83
+ ```
84
+ Prompt: Antibiotics are a type of medication used to treat bacterial infections. They work by either killing the bacteria or preventing them from reproducing, allowing the body’s immune system to fight off the infection. Antibiotics are usually taken orally in the form of pills, capsules, or liquid solutions, or sometimes administered intravenously. They are not effective against viral infections, and using them inappropriately can lead to antibiotic resistance. Explain the above in one sentence.
85
+ ```
86
+
87
+ | Task | Complexity | Creativity | Reasoning | Contextual Knowledge | Domain Knowledge | Constraints | # of Few Shots |
88
+ |-----------------|------------|------------|-----------|-----------------------|------------------|-------------|----------------|
89
+ | Summarization | 0.133 | 0.003 | 0.014 | 0.003 | 0.644 | 0.211 | 0 |
90
+
91
+ # Software Integration
92
+ * Runtime Engine: Python 3.10 and NeMo Curator
93
+ * Supported Hardware Microarchitecture Compatibility: NVIDIA GPU, Volta™ or higher (compute capability 7.0+), CUDA 12 (or above)
94
+ * Preferred/Supported Operating System(s): Ubuntu 22.04/20.04
95
+
96
+ # Model Version
97
+ Prompt Task and Complexity Classifier v1.1
98
+
99
+ # Training, Testing, and Evaluation Datasets
100
+ ## Training Data
101
+ * 4024 English prompts with task distribution outlined below
102
+ * Prompts were annotated by humans according to task and complexity taxonomies
103
+
104
+ Task distribution:
105
+ | Task | Count |
106
+ |------------------|-------|
107
+ | Open QA | 1214 |
108
+ | Closed QA | 786 |
109
+ | Text Generation | 480 |
110
+ | Chatbot | 448 |
111
+ | Classification | 267 |
112
+ | Summarization | 230 |
113
+ | Code Generation | 185 |
114
+ | Rewrite | 169 |
115
+ | Other | 104 |
116
+ | Brainstorming | 81 |
117
+ | Extraction | 60 |
118
+ | Total | 4024 |
119
+
120
+ ## Evaluation
121
+ For evaluation, Top-1 accuracy metric was used, which involves matching the category with the highest probability to the expected answer. Additionally, n-fold cross-validation was used to produce n different values for this metric to verify the consistency of the results. The table below displays the average of the top-1 accuracy values for the N folds calculated for each complexity dimension separately.
122
+
123
+ | | Task Accuracy | Creative Accuracy | Reasoning Accuracy | Contextual Accuracy | FewShots Accuracy | Domain Accuracy | Constraint Accuracy |
124
+ |-|------------------|-------------------|--------------------|---------------------|-------------------|-----------------|---------------------|
125
+ | Average of 10 Folds | 0.981 | 0.996 | 0.997 | 0.981 | 0.979 | 0.937 | 0.991 |
126
+
127
+ # Inference
128
+ * Engine: PyTorch
129
+ * Test Hardware: A10G
130
+
131
+ # Ethical Considerations
132
+ NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
133
+
134
+ Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability).