reformat model impl
Browse files
google_perch_lite/model.py
CHANGED
@@ -1,17 +1,19 @@
|
|
1 |
from iSparrow.sparrow_model_base import ModelBase
|
2 |
-
|
|
|
3 |
import tflite_runtime.interpreter as tflite
|
4 |
except ImportError:
|
5 |
import tensorflow.lite as tflite
|
6 |
|
7 |
-
from iSparrow import utils
|
8 |
-
from iSparrow import ModelBase
|
9 |
|
10 |
import numpy as np
|
11 |
from pathlib import Path
|
12 |
from scipy.special import softmax
|
13 |
|
14 |
-
|
|
|
15 |
"""
|
16 |
Model Implementation of a iSparrow model that uses the google perch tflite model.
|
17 |
|
@@ -19,7 +21,7 @@ class Model(ModelBase):
|
|
19 |
ModelBase (iSparrow.ModelBase): Model base class that provides the interface through which to interact with iSparrow.
|
20 |
"""
|
21 |
|
22 |
-
def __init__(self, model_path: str, num_threads: int = 1, **kwargs):
|
23 |
"""
|
24 |
__init__ Create a new model instance that uses the google perch tflite converted model.
|
25 |
|
@@ -48,7 +50,6 @@ class Model(ModelBase):
|
|
48 |
|
49 |
self.output_layer_index = output_details[1]["index"]
|
50 |
|
51 |
-
|
52 |
def predict(self, sample: np.array) -> np.array:
|
53 |
"""
|
54 |
predict Make inference about the bird species for the preprocessed data passed to this function as arguments.
|
@@ -75,7 +76,6 @@ class Model(ModelBase):
|
|
75 |
|
76 |
return confidence
|
77 |
|
78 |
-
|
79 |
@classmethod
|
80 |
def from_cfg(cls, sparrow_folder: str, cfg: dict):
|
81 |
"""
|
|
|
1 |
from iSparrow.sparrow_model_base import ModelBase
|
2 |
+
|
3 |
+
try:
|
4 |
import tflite_runtime.interpreter as tflite
|
5 |
except ImportError:
|
6 |
import tensorflow.lite as tflite
|
7 |
|
8 |
+
from iSparrow import utils
|
9 |
+
from iSparrow import ModelBase
|
10 |
|
11 |
import numpy as np
|
12 |
from pathlib import Path
|
13 |
from scipy.special import softmax
|
14 |
|
15 |
+
|
16 |
+
class Model(ModelBase):
|
17 |
"""
|
18 |
Model Implementation of a iSparrow model that uses the google perch tflite model.
|
19 |
|
|
|
21 |
ModelBase (iSparrow.ModelBase): Model base class that provides the interface through which to interact with iSparrow.
|
22 |
"""
|
23 |
|
24 |
+
def __init__(self, model_path: str, num_threads: int = 1, **kwargs):
|
25 |
"""
|
26 |
__init__ Create a new model instance that uses the google perch tflite converted model.
|
27 |
|
|
|
50 |
|
51 |
self.output_layer_index = output_details[1]["index"]
|
52 |
|
|
|
53 |
def predict(self, sample: np.array) -> np.array:
|
54 |
"""
|
55 |
predict Make inference about the bird species for the preprocessed data passed to this function as arguments.
|
|
|
76 |
|
77 |
return confidence
|
78 |
|
|
|
79 |
@classmethod
|
80 |
def from_cfg(cls, sparrow_folder: str, cfg: dict):
|
81 |
"""
|