Fix CXRBertOutput errors
#5
by
sambt
- opened
- modeling_cxrbert.py +4 -2
modeling_cxrbert.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3 |
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
4 |
# ------------------------------------------------------------------------------------------
|
5 |
|
|
|
6 |
from typing import Any, Optional, Tuple, Union
|
7 |
|
8 |
import torch
|
@@ -16,9 +17,10 @@ from .configuration_cxrbert import CXRBertConfig
|
|
16 |
|
17 |
BERTTupleOutput = Tuple[T, T, T, T, T]
|
18 |
|
|
|
19 |
class CXRBertOutput(ModelOutput):
|
20 |
-
last_hidden_state: torch.FloatTensor
|
21 |
-
logits: torch.FloatTensor
|
22 |
cls_projected_embedding: Optional[torch.FloatTensor] = None
|
23 |
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
24 |
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
|
|
3 |
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
4 |
# ------------------------------------------------------------------------------------------
|
5 |
|
6 |
+
from dataclasses import dataclass
|
7 |
from typing import Any, Optional, Tuple, Union
|
8 |
|
9 |
import torch
|
|
|
17 |
|
18 |
BERTTupleOutput = Tuple[T, T, T, T, T]
|
19 |
|
20 |
+
@dataclass
|
21 |
class CXRBertOutput(ModelOutput):
|
22 |
+
last_hidden_state: torch.FloatTensor = None
|
23 |
+
logits: torch.FloatTensor = None
|
24 |
cls_projected_embedding: Optional[torch.FloatTensor] = None
|
25 |
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
26 |
attentions: Optional[Tuple[torch.FloatTensor]] = None
|