Dr. Richard Zinck
Error messages
5b273b0
"""Module to analyze audio samples.
"""
import argparse
import datetime
import json
import operator
import os
import sys
from multiprocessing import Pool, freeze_support
import numpy as np
import audio
import config as cfg
import model
import species
import utils
import subprocess
import pathlib
def load_codes():
    """Read the eBird code table from ``cfg.CODES_FILE``.

    Returns:
        The parsed JSON content of ``cfg.CODES_FILE`` (the eBird codes).
    """
    with open(cfg.CODES_FILE, "r") as codes_file:
        return json.load(codes_file)
def _split_label(label: str):
    """Split a "<scientific>_<common>" label into (scientific, common).

    A label without "_" is returned as both parts, mirroring
    ``label.split("_", 1)[0]`` / ``[-1]``.
    """
    parts = label.split("_", 1)
    return parts[0], parts[-1]


def _iter_detections(r: dict[str, list]):
    """Yield the detections in ``r`` that pass the output filters.

    A detection passes when its score is above ``cfg.MIN_CONFIDENCE`` and, if
    ``cfg.SPECIES_LIST`` is non-empty, its label is in that list. Segments are
    visited in ascending start time (see get_sorted_timestamps); within a
    segment the stored order of ``r[timestamp]`` is kept.

    Args:
        r: The dictionary with {segment: scores}; each score entry is indexed
            as (label, confidence).

    Yields:
        Tuples ``(timestamp, start, end, species, label, confidence)`` where
        ``start``/``end`` are the string halves of ``timestamp`` and ``label``
        is the entry of ``cfg.TRANSLATED_LABELS`` matching ``species``.
    """
    for timestamp in get_sorted_timestamps(r):
        start, end = timestamp.split("-", 1)
        for c in r[timestamp]:
            species, confidence = c[0], c[1]
            if confidence > cfg.MIN_CONFIDENCE and (not cfg.SPECIES_LIST or species in cfg.SPECIES_LIST):
                label = cfg.TRANSLATED_LABELS[cfg.LABELS.index(species)]
                yield timestamp, start, end, species, label, confidence


def save_result_file(r: dict[str, list], path: str, afile_path: str):
    """Saves the results to the hard drive.

    The format is chosen by ``cfg.RESULT_TYPE``:

    - "table": Raven selection table (tab-separated).
    - "audacity": Audacity label track (no header).
    - "r": comma-separated, one row per detection with run settings.
    - "kaleidoscope": comma-separated Kaleidoscope layout.
    - anything else: plain CSV.

    Args:
        r: The dictionary with {segment: scores}.
        path: The path where the result should be saved.
        afile_path: The path to audio file.

    Returns:
        The text that was written to ``path``.
    """
    # Make folder if it doesn't exist
    if os.path.dirname(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)

    detections = _iter_detections(r)

    if cfg.RESULT_TYPE == "table":
        # Raven selection table: header line, then one numbered selection per detection.
        parts = ["Selection\tView\tChannel\tBegin Time (s)\tEnd Time (s)\tSpecies Code\tCommon Name\tConfidence\n"]
        for selection_id, (_, start, end, species, label, confidence) in enumerate(detections, start=1):
            parts.append("{}\tSpectrogram 1\t1\t{}\t{}\t{}\t{}\t{:.4f}\n".format(
                selection_id,
                start,
                end,
                cfg.CODES[species] if species in cfg.CODES else species,
                _split_label(label)[1],
                confidence,
            ))
    elif cfg.RESULT_TYPE == "audacity":
        # Audacity timeline labels: no header, "start<TAB>end<TAB>label<TAB>confidence".
        parts = [
            "{}\t{}\t{:.4f}\n".format(timestamp.replace("-", "\t"), label.replace("_", ", "), confidence)
            for timestamp, _, _, _, label, confidence in detections
        ]
    elif cfg.RESULT_TYPE == "r":
        # Output format for R: the header has no trailing newline, every row starts with one.
        parts = ["filepath,start,end,scientific_name,common_name,confidence,lat,lon,week,"
                 "overlap,sensitivity,min_conf,species_list,model"]
        # Undo the mirroring done in set_analysis_parameters to report the --sensitivity value.
        sensitivity = (1.0 - cfg.SIGMOID_SENSITIVITY) + 1.0
        model_name = os.path.basename(cfg.MODEL_PATH)
        for _, start, end, _, label, confidence in detections:
            scientific, common = _split_label(label)
            parts.append("\n{},{},{},{},{},{:.4f},{:.4f},{:.4f},{},{},{},{},{},{}".format(
                afile_path,
                start,
                end,
                scientific,
                common,
                confidence,
                cfg.LATITUDE,
                cfg.LONGITUDE,
                cfg.WEEK,
                cfg.SIG_OVERLAP,
                sensitivity,
                cfg.MIN_CONFIDENCE,
                cfg.SPECIES_LIST_FILE,
                model_name,
            ))
    elif cfg.RESULT_TYPE == "kaleidoscope":
        # Output format for Kaleidoscope: same newline convention as the R format.
        parts = ["INDIR,FOLDER,IN FILE,OFFSET,DURATION,scientific_name,"
                 "common_name,confidence,lat,lon,week,overlap,sensitivity"]
        folder_path, filename = os.path.split(afile_path)
        parent_folder, folder_name = os.path.split(folder_path)
        sensitivity = (1.0 - cfg.SIGMOID_SENSITIVITY) + 1.0
        for _, start, end, _, label, confidence in detections:
            scientific, common = _split_label(label)
            parts.append("\n{},{},{},{},{},{},{},{:.4f},{:.4f},{:.4f},{},{},{}".format(
                parent_folder.rstrip("/"),
                folder_name,
                filename,
                start,
                float(end) - float(start),
                scientific,
                common,
                confidence,
                cfg.LATITUDE,
                cfg.LONGITUDE,
                cfg.WEEK,
                cfg.SIG_OVERLAP,
                sensitivity,
            ))
    else:
        # CSV output file
        parts = ["Start (s),End (s),Scientific name,Common name,Confidence\n"]
        for _, start, end, _, label, confidence in detections:
            scientific, common = _split_label(label)
            parts.append("{},{},{},{},{:.4f}\n".format(start, end, scientific, common, confidence))

    out_string = "".join(parts)

    # Save as file
    with open(path, "w", encoding="utf-8") as rfile:
        rfile.write(out_string)

    return out_string
def get_sorted_timestamps(results: dict[str, list]) -> list[str]:
    """Sorts the segment keys of the results by their start time.

    Keys have the form "start-end"; they are ordered numerically by the part
    before the first "-", so "9-12" comes before "10.5-13.5".

    Args:
        results: The dictionary with {segment: scores}.

    Returns:
        The list of segment keys, sorted by ascending start time. The scores
        are not included; look them up in ``results``.
    """
    return sorted(results, key=lambda t: float(t.split("-", 1)[0]))
def get_raw_audio_from_file(fpath: str):
    """Reads an audio file and splits the signal into chunks.

    Failures are printed and written to the error log instead of being raised.

    Args:
        fpath: Path to the audio file.

    Returns:
        The signal split into a list of chunks (via ``audio.splitSignal`` with
        cfg.SIG_LENGTH / cfg.SIG_OVERLAP / cfg.SIG_MINLEN), or None if the
        file could not be opened or split.
    """
    # Open file
    try:
        sig, _ = audio.openAudioFile(fpath, cfg.SAMPLE_RATE)
    except Exception as ex:
        print(f"Error: Cannot open audio file at load {fpath}", flush=True)
        utils.writeErrorLog(ex)
        # Without a signal there is nothing to split.
        return None

    # Split into raw audio chunks
    try:
        return audio.splitSignal(sig, cfg.SAMPLE_RATE, cfg.SIG_LENGTH, cfg.SIG_OVERLAP, cfg.SIG_MINLEN)
    except Exception as ex:
        print(f"Error: Cannot chunk audio file {fpath}", flush=True)
        utils.writeErrorLog(ex)
        return None
def predict(samples):
    """Runs the model on a batch of samples.

    Args:
        samples: Samples to be predicted; converted to a float32 array.

    Returns:
        The raw model output, or ``model.flat_sigmoid`` of it (with the
        negated cfg.SIGMOID_SENSITIVITY) when cfg.APPLY_SIGMOID is set.
    """
    batch = np.array(samples, dtype="float32")
    scores = model.predict(batch)
    if not cfg.APPLY_SIGMOID:
        # Caller wants the logits as-is.
        return scores
    return model.flat_sigmoid(np.array(scores), sensitivity=-cfg.SIGMOID_SENSITIVITY)
def _output_is_file() -> bool:
    """Return True if cfg.OUTPUT_PATH names a single .txt/.csv file rather than a folder."""
    return cfg.OUTPUT_PATH.rsplit(".", 1)[-1].lower() in ["txt", "csv"]


def _result_file_path(fpath: str) -> str:
    """Return the path the results for the audio file ``fpath`` are written to.

    If cfg.OUTPUT_PATH is a file it is used as-is. Otherwise the layout below
    cfg.INPUT_PATH is mirrored inside cfg.OUTPUT_PATH (sub-folders are created)
    and the audio extension is replaced by a suffix matching cfg.RESULT_TYPE.
    """
    if _output_is_file():
        return cfg.OUTPUT_PATH

    # Path of the audio file relative to the input folder.
    rpath = fpath.replace(cfg.INPUT_PATH, "")
    rpath = rpath[1:] if rpath[:1] in ["/", "\\"] else rpath
    if not rpath:
        # A single input file was given (fpath == INPUT_PATH): name the result after it.
        rpath = os.path.basename(fpath)

    # Make target directory if it doesn't exist
    os.makedirs(os.path.join(cfg.OUTPUT_PATH, os.path.dirname(rpath)), exist_ok=True)

    if cfg.RESULT_TYPE == "table":
        rtype = ".bat.selection.table.txt"
    elif cfg.RESULT_TYPE == "audacity":
        rtype = ".bat.results.txt"
    else:
        rtype = ".bat.results.csv"
    return os.path.join(cfg.OUTPUT_PATH, rpath.rsplit(".", 1)[0] + rtype)


def _append_to_combined_results(out_string: str) -> None:
    """Append the detection rows of one result file to the shared Results.csv.

    The header line produced by save_result_file is dropped (Audacity output
    has no header, so nothing is dropped there). Rows are newline-terminated so
    that rows from consecutive files never run together.
    """
    if cfg.RESULT_TYPE == "audacity":
        rows = out_string
    else:
        rows = out_string.partition("\n")[2]
    if rows and not rows.endswith("\n"):
        # r / kaleidoscope rows are newline-prefixed, so the last one lacks a terminator.
        rows += "\n"

    if _output_is_file():
        combined_path = cfg.OUTPUT_PATH + "Results.csv"
    else:
        combined_path = os.path.join(cfg.OUTPUT_PATH, "Results.csv")
    with open(combined_path, "a", encoding="utf-8") as rfile:
        rfile.write(rows)


def analyze_file(item):
    """Analyzes a file.

    Predicts the scores for every chunk of the file, saves them with
    save_result_file and appends the detections to the combined Results.csv.

    Args:
        item: Tuple containing (file path, config)

    Returns:
        True if the file was analyzed and saved successfully, False otherwise.
    """
    # Get file path and restore cfg (each worker process gets its own copy)
    fpath: str = item[0]
    cfg.set_config(item[1])

    # Start time
    start_time = datetime.datetime.now()

    # Status
    print(f"Analyzing {fpath}", flush=True)

    # Open audio file and split into cfg.SIG_LENGTH chunks.
    # get_raw_audio_from_file reports its own errors and returns None on failure.
    chunks = get_raw_audio_from_file(fpath)
    if chunks is None:
        return False

    # Process each chunk
    try:
        start, end = 0, cfg.SIG_LENGTH
        results = {}
        samples = []
        timestamps = []

        for chunk_index, chunk in enumerate(chunks):
            # Add to batch
            samples.append(chunk)
            timestamps.append([start, end])

            # Advance start and end
            start += cfg.SIG_LENGTH - cfg.SIG_OVERLAP
            end = start + cfg.SIG_LENGTH

            # Keep filling the batch unless it is full or this is the last chunk
            if len(samples) < cfg.BATCH_SIZE and chunk_index < len(chunks) - 1:
                continue

            prediction = predict(samples)

            for (s_start, s_end), pred in zip(timestamps, prediction):
                # Store every (label, score) pair, highest score first
                p_sorted = sorted(zip(cfg.LABELS, pred), key=operator.itemgetter(1), reverse=True)
                results[str(s_start) + "-" + str(s_end)] = p_sorted

            # Clear batch
            samples = []
            timestamps = []
    except Exception as ex:
        # Write error log
        print(f"Error: Cannot analyze audio file {fpath}.\n", flush=True)
        utils.writeErrorLog(ex)
        return False

    # Save results
    try:
        out_string = save_result_file(results, _result_file_path(fpath), fpath)
        _append_to_combined_results(out_string)
    except Exception as ex:
        # Write error log
        print(f"Error: Cannot save result for {fpath}.\n", flush=True)
        utils.writeErrorLog(ex)
        return False

    delta_time = (datetime.datetime.now() - start_time).total_seconds()
    print("Finished {} in {:.2f} seconds".format(fpath, delta_time), flush=True)
    return True
def set_analysis_location():
    """Select the BattyBirdNET model and labels for ``args.area``.

    Points cfg.CUSTOM_CLASSIFIER / cfg.LABELS_FILE at the area's
    ``BattyBirdNET-<area>-144kHz`` files under cfg.BAT_CLASSIFIER_LOCATION,
    loads cfg.LABELS, disables the lat/lon filter and, for areas that have
    one, switches ``args.locale``. Exits with an error message for an unknown
    area. When --classifier was given, the custom classifier configured by
    set_custom_classifier is kept instead of the area model.
    """
    # Supported areas -> locale to use for translated labels (None: keep args.locale).
    area_locales = {
        "Bavaria": "de",
        "Sweden": "se",
        "EU": None,
        "Scotland": None,
        "UK": None,
        "USA": None,
        "MarinCounty": None,
    }
    if args.area not in area_locales:
        # sys.exit works in scripts and frozen executables; the exit() builtin is for the REPL.
        sys.exit("Unknown location option.")

    args.lat = -1
    args.lon = -1

    if args.classifier is not None:
        # A user-supplied classifier takes precedence over the area model.
        return

    model_base = f"{cfg.BAT_CLASSIFIER_LOCATION}/BattyBirdNET-{args.area}-144kHz"
    cfg.CUSTOM_CLASSIFIER = model_base + ".tflite"
    cfg.LABELS_FILE = model_base + "_Labels.txt"
    cfg.LABELS = utils.readLines(cfg.LABELS_FILE)

    locale = area_locales[args.area]
    if locale is not None:
        args.locale = locale
def set_paths():
    """Resolve bundled resource paths against the script folder and set I/O paths.

    Each resource path in cfg is joined with the directory of the running
    script (requested in #3), so it no longer depends on the working
    directory. The --i / --o arguments are then copied to cfg.
    """
    base_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
    resource_attrs = (
        "MODEL_PATH",
        "LABELS_FILE",
        "TRANSLATED_LABELS_PATH",
        "MDATA_MODEL_PATH",
        "CODES_FILE",
        "ERROR_LOG_FILE",
        "BAT_CLASSIFIER_LOCATION",
    )
    for attr in resource_attrs:
        setattr(cfg, attr, os.path.join(base_dir, getattr(cfg, attr)))
    cfg.INPUT_PATH = args.i
    cfg.OUTPUT_PATH = args.o
def set_custom_classifier():
    """Switch to the classifier passed via --classifier, if any.

    The path is used as given (not joined with the script folder). The label
    file is expected next to it, named with ".tflite" replaced by
    "_Labels.txt". The lat/lon filter is disabled.
    """
    classifier_path = args.classifier
    if classifier_path is not None:
        cfg.CUSTOM_CLASSIFIER = classifier_path
        cfg.LABELS_FILE = classifier_path.replace(".tflite", "_Labels.txt")
        cfg.LABELS = utils.readLines(cfg.LABELS_FILE)
        args.lat, args.lon = -1, -1
def add_parser_arguments():
    """Register the command-line options on the module-level ``parser``.

    The help texts state the defaults that are actually applied.
    """
    parser.add_argument("--area",
                        default="EU",
                        help="Location. Values in ['Bavaria', 'EU', 'Sweden','Scotland', 'UK', 'USA', 'MarinCounty']. "
                             "Defaults to EU.")
    parser.add_argument("--sensitivity",
                        type=float,
                        default=1.0,
                        help="Detection sensitivity; Higher values result in higher sensitivity. "
                             "Values in [0.5, 1.5]. Defaults to 1.0."
                        )
    parser.add_argument("--min_conf",
                        type=float,
                        default=0.7,
                        help="Minimum confidence threshold. Values in [0.01, 0.99]. Defaults to 0.7.")
    parser.add_argument("--overlap",
                        type=float,
                        default=0.0,
                        help="Overlap of prediction segments. Values in [0.0, 2.9]. Defaults to 0.0."
                        )
    parser.add_argument("--rtype",
                        default="csv",
                        help="Specifies output format. Values in ['table', 'audacity', 'r', 'kaleidoscope', 'csv']. "
                             "Defaults to 'csv'. Use 'table' for a Raven selection table."
                        )
    parser.add_argument("--threads",
                        type=int,
                        default=4,
                        help="Number of CPU threads. Defaults to 4.")
    parser.add_argument("--batchsize",
                        type=int,
                        default=1,
                        help="Number of samples to process at the same time. Defaults to 1."
                        )
    parser.add_argument("--sf_thresh",
                        type=float,
                        default=0.03,
                        help="Minimum species occurrence frequency threshold for location filter. "
                             "Values in [0.01, 0.99]. Defaults to 0.03."
                        )
    parser.add_argument("--segment",
                        default="off",
                        help="Generate audio files containing the detected segments. "
                             "Values in ['on', 'off']. Defaults to 'off'."
                        )
    parser.add_argument("--spectrum",
                        default="off",
                        help="Generate mel spectrograms files containing the detected segments. "
                             "Values in ['on', 'off']. Defaults to 'off'."
                        )
    parser.add_argument("--i",
                        default=cfg.INPUT_PATH_SAMPLES,  # "put-your-files-here/",
                        help="Path to input file or folder. If this is a file, --o needs to be a file too.")
    parser.add_argument("--o",
                        default=cfg.OUTPUT_PATH_SAMPLES,
                        help="Path to output file or folder. If this is a file, --i needs to be a file too.")
    parser.add_argument("--classifier",
                        default=None,
                        help="Path to custom trained classifier. Defaults to None. "
                             "If set, --lat, --lon and --locale are ignored."
                        )
    parser.add_argument("--slist",
                        default="",
                        help='Path to species list file or folder. If folder is provided, species list needs to be '
                             'named "species_list.txt". If lat and lon are provided, this list will be ignored.'
                        )
    parser.add_argument("--lat",
                        type=float,
                        default=-1,
                        help="DISABLED. Set -1 to ignore.")
    parser.add_argument("--lon",
                        type=float,
                        default=-1,
                        help="DISABLED. Set -1 to ignore.")
    parser.add_argument("--week",
                        type=int,
                        default=-1,
                        help="DISABLED. Set -1 for year-round species list."
                        )
    parser.add_argument("--locale",
                        default="en",
                        help="Locale for translated labels. Overridden by --area Bavaria ('de') and "
                             "Sweden ('se'). Defaults to 'en'."
                        )
def load_ebird_codes():
    """Populate cfg.CODES from the eBird code file and cfg.LABELS from cfg.LABELS_FILE."""
    ebird_codes = load_codes()
    cfg.CODES = ebird_codes
    label_lines = utils.readLines(cfg.LABELS_FILE)
    cfg.LABELS = label_lines
def load_species_list():
    """Configure the species filter and report how many species are considered.

    With lat/lon both -1, the species list comes from --slist (a file, or a
    folder containing "species_list.txt", resolved against the script folder)
    when given. Otherwise the list is computed with species.getSpeciesList
    for the location/week and the --sf_thresh threshold (clamped to
    [0.01, 0.99]).
    """
    cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK = args.lat, args.lon, args.week
    cfg.LOCATION_FILTER_THRESHOLD = max(0.01, min(0.99, float(args.sf_thresh)))
    base_dir = os.path.dirname(os.path.abspath(sys.argv[0]))

    location_given = not (cfg.LATITUDE == -1 and cfg.LONGITUDE == -1)
    if location_given:
        cfg.SPECIES_LIST_FILE = None
        cfg.SPECIES_LIST = species.getSpeciesList(cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK, cfg.LOCATION_FILTER_THRESHOLD)
    elif args.slist:
        list_path = os.path.join(base_dir, args.slist)
        if os.path.isdir(list_path):
            list_path = os.path.join(list_path, "species_list.txt")
        cfg.SPECIES_LIST_FILE = list_path
        cfg.SPECIES_LIST = utils.readLines(cfg.SPECIES_LIST_FILE)
    else:
        cfg.SPECIES_LIST_FILE = None

    # An empty species list means "no filter", i.e. every label is considered.
    species_count = len(cfg.SPECIES_LIST) if cfg.SPECIES_LIST else len(cfg.LABELS)
    print(f"Species list contains {species_count} species")
def parse_input_files():
    """Fill cfg.FILE_LIST with the audio files found at cfg.INPUT_PATH.

    A folder is scanned with utils.collect_audio_files; any other path is
    treated as a single file.
    """
    if not os.path.isdir(cfg.INPUT_PATH):
        cfg.FILE_LIST = [cfg.INPUT_PATH]
        return
    cfg.FILE_LIST = utils.collect_audio_files(cfg.INPUT_PATH)
    print(f"Found {len(cfg.FILE_LIST)} files to analyze")
def set_analysis_parameters():
    """Copy the analysis options into cfg, clamped to their valid ranges.

    The user-facing --sensitivity is mirrored around 1.0 before clamping, so
    cfg.SIGMOID_SENSITIVITY holds 1.0 - (sensitivity - 1.0) within [0.5, 1.5].
    """
    min_conf = float(args.min_conf)
    cfg.MIN_CONFIDENCE = max(0.01, min(0.99, min_conf))

    mirrored_sensitivity = 1.0 - (float(args.sensitivity) - 1.0)
    cfg.SIGMOID_SENSITIVITY = max(0.5, min(mirrored_sensitivity, 1.5))

    overlap = float(args.overlap)
    cfg.SIG_OVERLAP = max(0.0, min(2.9, overlap))

    cfg.BATCH_SIZE = max(1, int(args.batchsize))
def set_hardware_parameters():
    """Distribute --threads between worker processes and cfg.TFLITE_THREADS.

    For a folder input the threads become cfg.CPU_THREADS (the worker pool
    size used for the file list) with one TFLite thread each. For a single
    file there is one worker and all threads go to cfg.TFLITE_THREADS.
    """
    thread_count = max(1, int(args.threads))
    if os.path.isdir(cfg.INPUT_PATH):
        cfg.CPU_THREADS, cfg.TFLITE_THREADS = thread_count, 1
    else:
        cfg.CPU_THREADS, cfg.TFLITE_THREADS = 1, thread_count
def load_translated_labels():
    """Load the labels translated to ``args.locale`` into cfg.TRANSLATED_LABELS.

    Looks for "<labels file name>_<locale>.txt" in cfg.TRANSLATED_BAT_LABELS_PATH.
    For locale "en", or when no translation file exists, cfg.LABELS is used.
    """
    # Resolve against the script folder like the other resources in set_paths (#3);
    # os.path.join leaves an already-absolute path unchanged.
    script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
    cfg.TRANSLATED_LABELS_PATH = os.path.join(script_dir, cfg.TRANSLATED_BAT_LABELS_PATH)

    lfile = os.path.join(
        cfg.TRANSLATED_LABELS_PATH,
        os.path.basename(cfg.LABELS_FILE).replace(".txt", "_{}.txt".format(args.locale)),
    )

    if args.locale != "en" and os.path.isfile(lfile):
        cfg.TRANSLATED_LABELS = utils.readLines(lfile)
    else:
        cfg.TRANSLATED_LABELS = cfg.LABELS
def check_result_type():
    """Store --rtype (lower-cased) in cfg.RESULT_TYPE, falling back to "csv" if unknown."""
    requested_type = args.rtype.lower()
    supported_types = ("table", "audacity", "r", "kaleidoscope", "csv")
    if requested_type in supported_types:
        cfg.RESULT_TYPE = requested_type
    else:
        cfg.RESULT_TYPE = "csv"
        print("Unknown output option. Using csv output.")
if __name__ == "__main__":
    freeze_support()  # Freeze support for executable

    parser = argparse.ArgumentParser(description="Analyze audio files with BattyBirdNET")
    add_parser_arguments()
    args = parser.parse_args()

    # Configure cfg from the CLI. Order matters: later steps read values set by
    # earlier ones (e.g. load_translated_labels uses cfg.LABELS_FILE).
    set_paths()
    load_ebird_codes()
    set_custom_classifier()
    check_result_type()
    set_analysis_location()
    load_translated_labels()
    load_species_list()
    parse_input_files()
    set_analysis_parameters()
    set_hardware_parameters()

    # Add config items to each file list entry.
    # We have to do this for Windows which does not
    # support fork() and thus each process has to
    # have its own config.
    flist = [(f, cfg.get_config()) for f in cfg.FILE_LIST]

    # Analyze files
    if cfg.CPU_THREADS < 2:
        for entry in flist:
            analyze_file(entry)
    else:
        with Pool(cfg.CPU_THREADS) as p:
            p.map(analyze_file, flist)

    # Run the helper scripts with the interpreter executing this script, so the
    # same environment (e.g. a virtualenv) is used and no "python3" binary is required.
    if args.segment == "on" or args.spectrum == "on":
        subprocess.run([sys.executable, "segments.py"])

    if args.spectrum == "on":
        # Iterate through the segments folder subfolders and call the plotter on each
        print("Spectrums in progress ...")
        script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
        root_dir = pathlib.Path(os.path.join(script_dir, args.i + "/segments"))
        if not root_dir.is_dir():
            print(f"No segments folder found at {root_dir}; skipping spectrums.")
        else:
            for dir_name in os.listdir(root_dir):
                f = os.path.join(root_dir, dir_name)
                if not os.path.isfile(f):
                    print("Spectrum in progres for: " + f)
                    cmd = [sys.executable, "batchspec.py", f, f]
                    subprocess.run(cmd)
# A few examples to test
# python3 analyze.py --i example/ --o example/ --slist example/ --min_conf 0.5 --threads 4
# python3 analyze.py --i example/soundscape.wav --o example/soundscape.BirdNET.selection.table.txt --slist example/species_list.txt --threads 8
# python3 analyze.py --i example/ --o example/ --lat 42.5 --lon -76.45 --week 4 --sensitivity 1.0 --rtype table --locale de