|
import os |
|
import time |
|
import csv |
|
import datetime |
|
import gradio |
|
import schedule |
|
from gradio import utils |
|
import huggingface_hub |
|
from pathlib import Path |
|
from src.models.bert import BERTClassifier |
|
from src.utils.utilities import Utility |
|
|
|
# Pretrained BERT classifier fine-tuned for GoEmotions; weights are fetched
# from the Hub when this module is imported.
model = BERTClassifier(model_name='jeevavijay10/nlp-goemotions-bert')

# Ordered emotion label names; indexed by the model's predicted class id in predict().
classes = Utility().read_emotion_list()

# Hugging Face access token for pushing logs; None when the env var is unset.
hf_token = os.getenv("HF_TOKEN")

# Local directory into which the logging dataset repo is cloned.
dataset_dir = "logs"

# Column order for every row appended to logs/data.csv.
headers = ["input", "output", "timestamp", "elapsed"]

# Clone (or reuse) the dataset repo that stores prediction logs.
# NOTE(review): huggingface_hub.Repository is deprecated in recent
# huggingface_hub releases in favor of HfApi — confirm the pinned version
# still ships it.
repo = huggingface_hub.Repository(

    local_dir=dataset_dir,

    clone_from="https://huggingface.co/datasets/jeevavijay10/senti-pred-gradio",

    token=hf_token,

)

# Bring the local clone (including LFS-tracked files) up to date before logging.
repo.git_pull(lfs=True)
|
|
|
def log_record(vals):
    """Append one prediction record to the CSV log and tick the sync scheduler.

    vals -- list of cell values matching the module-level ``headers`` order.
    """
    csv_path = Path(dataset_dir) / "data.csv"
    # A header row is only needed when the file does not exist yet.
    needs_header = not csv_path.exists()
    with open(csv_path, "a", newline="", encoding="utf-8") as fh:
        out = csv.writer(fh)
        if needs_header:
            out.writerow(utils.sanitize_list_for_csv(headers))
        out.writerow(utils.sanitize_list_for_csv(vals))
    # Let the scheduler fire sync_logs if its 5-minute interval has elapsed.
    schedule.run_pending()
    print(f"Last Sync: {job.last_run}")
|
|
|
def predict(sentence):
    """Classify *sentence* into an emotion label, log the call, and return the label."""
    stamp = datetime.datetime.now().isoformat()
    started = time.time()
    preds = model.evaluate([sentence])
    elapsed = time.time() - started

    # Map the predicted class id onto its human-readable emotion name.
    label = classes[preds[0]]

    print(f"Sentence: {sentence} \nPrediction: {preds[0]} - {label}")
    log_record([sentence, label, stamp, str(elapsed)])

    return label
|
|
|
|
|
def sync_logs():
    """Commit and push any new log rows to the remote dataset repo.

    Bug fix: the original assigned ``result`` only inside the
    ``if not repo.is_repo_clean():`` branch but printed it unconditionally,
    raising ``NameError`` whenever the repo was already clean. Restructured
    as a guard clause so the push (and its printout) happen only when there
    is something to sync.
    """
    print(f"Repo Clean: {repo.is_repo_clean()}")
    if repo.is_repo_clean():
        # Nothing new to commit; skip all git work.
        return
    repo.git_add()
    repo.git_commit()
    # Pull first so the push does not fail on commits made remotely.
    repo.git_pull(lfs=True)
    result = repo.git_push()
    print(result)
|
|
|
# Register a sync every 5 minutes; the schedule library only runs jobs when
# run_pending() is called, which happens inside log_record() on each prediction.
job = schedule.every(5).minutes.do(sync_logs)

print("Scheduler engaged")

# Launch the web UI: a single text input mapped to a single text output,
# with Gradio's built-in flagging disabled.
gradio.Interface(

    fn=predict,

    inputs="text",

    outputs="text",

    allow_flagging='never'

).launch()
|
|