MaHaWo commited on
Commit
f16709f
1 Parent(s): 3b7bf6d

work on arguments

Browse files
birdnet_custom_v2.4/model.py CHANGED
@@ -47,7 +47,7 @@ class Model(ModelBase):
47
  self,
48
  default_model_path: str = None,
49
  model_path: str = None,
50
- sensitivity: float = 1.0,
51
  num_threads: int = 1,
52
  **kwargs
53
  ):
@@ -58,8 +58,6 @@ class Model(ModelBase):
58
  classifier_model_path = str(Path(model_path) / "model.tflite")
59
  classifier_labels_path = str(Path(model_path) / "labels.txt")
60
 
61
- self.sensitivity = sensitivity
62
-
63
  # check custom classifier paths through function due to higher complexity
64
  self._check_classifier_path_integrity(
65
  classifier_model_path, classifier_labels_path
@@ -71,6 +69,7 @@ class Model(ModelBase):
71
 
72
  self.input_layer_index = None
73
  self.output_layer_index = None
 
74
 
75
  # use the super class for handling the default models and load the custom ones in this one
76
  super().__init__(
@@ -78,7 +77,6 @@ class Model(ModelBase):
78
  model_path=classifier_model_path,
79
  labels_path=classifier_labels_path,
80
  num_threads=num_threads,
81
- sensitivity=sensitivity,
82
  **kwargs
83
  )
84
 
 
47
  self,
48
  default_model_path: str = None,
49
  model_path: str = None,
50
+ sigmoid_sensitivity: float = 1.0,
51
  num_threads: int = 1,
52
  **kwargs
53
  ):
 
58
  classifier_model_path = str(Path(model_path) / "model.tflite")
59
  classifier_labels_path = str(Path(model_path) / "labels.txt")
60
 
 
 
61
  # check custom classifier paths through function due to higher complexity
62
  self._check_classifier_path_integrity(
63
  classifier_model_path, classifier_labels_path
 
69
 
70
  self.input_layer_index = None
71
  self.output_layer_index = None
72
+ self.sigmoid_sensitivity = sigmoid_sensitivity
73
 
74
  # use the super class for handling the default models and load the custom ones in this one
75
  super().__init__(
 
77
  model_path=classifier_model_path,
78
  labels_path=classifier_labels_path,
79
  num_threads=num_threads,
 
80
  **kwargs
81
  )
82
 
birdnet_default_v2.4/model.py CHANGED
@@ -16,7 +16,7 @@ class Model(ModelBase):
16
  self,
17
  model_path: str,
18
  num_threads: int = 1,
19
- sensitivity: float = 1.0,
20
  species_list_file: str = None,
21
  **kwargs
22
  ):
@@ -26,7 +26,7 @@ class Model(ModelBase):
26
  Args:
27
  model_path (str): Path to the location of the model file to be loaded
28
  num_threads (int, optional): Number of threads used for inference. Defaults to 1.
29
- sensitivity (float, optional): Parameter of the sigmoid function used to compute probabilities. Defaults to 1.0.
30
 
31
  Raises:
32
  AnalyzerConfigurationError: The model file 'model.tflite' doesn't exist at the given path.
@@ -43,10 +43,11 @@ class Model(ModelBase):
43
  model_path,
44
  labels_path,
45
  num_threads=num_threads,
46
- sensitivity=sensitivity,
47
  **kwargs
48
  )
49
 
 
 
50
  # store input and output index to not have to retrieve them each time an inference is made
51
  input_details = self.model.get_input_details()
52
 
@@ -84,7 +85,9 @@ class Model(ModelBase):
84
 
85
  prediction = self.model.get_tensor(self.output_layer_index)
86
 
87
- confidence = self._sigmoid(np.array(prediction), sensitivity=-self.sensitivity)
 
 
88
 
89
  return confidence
90
 
 
16
  self,
17
  model_path: str,
18
  num_threads: int = 1,
19
+ sigmoid_sensitivity: float = 1.0,
20
  species_list_file: str = None,
21
  **kwargs
22
  ):
 
26
  Args:
27
  model_path (str): Path to the location of the model file to be loaded
28
  num_threads (int, optional): Number of threads used for inference. Defaults to 1.
29
+ sigmoid_sensitivity (float, optional): Parameter of the sigmoid function used to compute probabilities. Defaults to 1.0.
30
 
31
  Raises:
32
  AnalyzerConfigurationError: The model file 'model.tflite' doesn't exist at the given path.
 
43
  model_path,
44
  labels_path,
45
  num_threads=num_threads,
 
46
  **kwargs
47
  )
48
 
49
+ self.sigmoid_sensitivity = sigmoid_sensitivity
50
+
51
  # store input and output index to not have to retrieve them each time an inference is made
52
  input_details = self.model.get_input_details()
53
 
 
85
 
86
  prediction = self.model.get_tensor(self.output_layer_index)
87
 
88
+ confidence = self._sigmoid(
89
+ np.array(prediction), sigmoid_sensitivity=-self.sigmoid_sensitivity
90
+ )
91
 
92
  return confidence
93
 
google_bird_classification/model.py CHANGED
@@ -7,7 +7,7 @@ import pandas as pd
7
 
8
  class Model(ModelBase):
9
 
10
- def __init__(self, model_path: str, num_threads: int = 1, species_list_file=None, **kwargs):
11
  """
12
  __init__ Create a new Model instance using the google perch model.
13
 
@@ -22,12 +22,7 @@ class Model(ModelBase):
22
  self.class_mask = None # used later
23
 
24
  super().__init__(
25
- "google_perch",
26
- model_path,
27
- labels_path,
28
- num_threads=num_threads,
29
- **kwargs
30
- # sensitivity kwarg doesn't exist here
31
  ) # num_threads doesn't do anything here.
32
 
33
  def predict(self, data: np.array):
@@ -40,10 +35,8 @@ class Model(ModelBase):
40
  list: List of (label, inferred_probability)
41
  """
42
 
43
- results = self.labels.copy()
44
-
45
  # README: this should be parallelized??
46
- logits, embeddings = self.model.infer_tf(
47
  np.array(
48
  [
49
  data,
 
7
 
8
  class Model(ModelBase):
9
 
10
+ def __init__(self, model_path: str, num_threads: int = 1, **kwargs):
11
  """
12
  __init__ Create a new Model instance using the google perch model.
13
 
 
22
  self.class_mask = None # used later
23
 
24
  super().__init__(
25
+ "google_perch", model_path, labels_path, num_threads=num_threads, **kwargs
 
 
 
 
 
26
  ) # num_threads doesn't do anything here.
27
 
28
  def predict(self, data: np.array):
 
35
  list: List of (label, inferred_probability)
36
  """
37
 
 
 
38
  # README: this should be parallelized??
39
+ logits, _ = self.model.infer_tf(
40
  np.array(
41
  [
42
  data,