# ECG2HRV/src/deprecated/pipeline_wrapper.py
# Last commit 9e4b4a3 by nina-m-m: "Refactor deprecated code" (file size 2.05 kB)
from transformers import Pipeline
from src.deprecated.conversion import csv_to_pandas
from src.deprecated.pydantic_models import ECGConfig, ECGSample
from src.deprecated.ecg_processing import process_batch
class MyPipeline(Pipeline):
    """``transformers.Pipeline`` that turns ECG CSV data into features.

    Stages:

    * :meth:`preprocess` loads the CSV, collapses the per-timestamp rows into
      one record per sample and extracts the shared ``configs.*`` and
      ``batch.*`` metadata.
    * :meth:`_forward` runs ``process_batch`` on the samples and the config.
    * :meth:`postprocess` returns the forward output unchanged.
    """

    # Columns holding one value per timestamp; they are collected into lists.
    # Every other column is treated as per-sample metadata and used as a
    # grouping key, so each unique metadata combination becomes one sample.
    _COLS_TO_IMPLODE = ('timestamp_idx', 'ecg', 'label')

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments into per-stage parameter dicts.

        Returns the ``(preprocess, forward, postprocess)`` kwargs triple that
        ``transformers.Pipeline`` expects. Only ``maybe_arg`` is recognised;
        it is routed to :meth:`preprocess`. Other keywords are not forwarded.
        """
        preprocess_kwargs = {}
        if "maybe_arg" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        return preprocess_kwargs, {}, {}

    @staticmethod
    def _first_row_fields(df, prefix: str) -> dict:
        """Return the first row's ``prefix*`` columns as a dict, prefix stripped."""
        cols = [col for col in df.columns if col.startswith(prefix)]
        first_row = df[cols].iloc[0].to_dict()
        return {key.removeprefix(prefix): value for key, value in first_row.items()}

    def preprocess(self, inputs: str, maybe_arg=None) -> dict:
        """Load a CSV and build the model input consumed by :meth:`_forward`.

        Args:
            inputs: CSV source passed to ``csv_to_pandas`` (presumably a path
                or CSV text; confirm against ``csv_to_pandas``).
            maybe_arg: Accepted so that the ``maybe_arg`` keyword routed here
                by :meth:`_sanitize_parameters` does not raise ``TypeError``.
                Currently unused.

        Returns:
            ``{"model_input": {"samples": list[ECGSample],
            "configs": ECGConfig, "batch": dict}}``. ``configs`` and ``batch``
            are read from the first sample only, on the assumption that they
            are identical across samples. TODO confirm.

        Raises:
            ValueError: If the CSV yields no rows, so there is no row to read
                the ``configs.*`` / ``batch.*`` metadata from.
        """
        df = csv_to_pandas(inputs)

        # Implode: one row per unique metadata combination, with the
        # per-timestamp columns gathered into lists. Group keys keep the CSV's
        # column order (a set difference would make it hash-randomised).
        # dropna=False keeps samples whose metadata contains NaN instead of
        # silently discarding them.
        group_cols = [col for col in df.columns if col not in self._COLS_TO_IMPLODE]
        df_imploded = (
            df.groupby(group_cols, dropna=False)
            .agg({col: list for col in self._COLS_TO_IMPLODE})
            .reset_index()
        )
        if df_imploded.empty:
            raise ValueError("No ECG rows found in input CSV")

        # Shared metadata, taken from the first sample.
        configs = ECGConfig(**self._first_row_fields(df_imploded, 'configs.'))
        batch = self._first_row_fields(df_imploded, 'batch.')

        # One ECGSample per imploded row (all columns, including metadata).
        samples = [
            ECGSample(**record)
            for record in df_imploded.to_dict(orient='records')
        ]

        model_input = {"samples": samples, "configs": configs, "batch": batch}
        return {"model_input": model_input}

    def _forward(self, model_inputs):
        """Run ``process_batch`` on the output of :meth:`preprocess`.

        ``model_inputs`` is ``{"model_input": {...}}`` as built by
        :meth:`preprocess`. The ``batch`` metadata is not currently passed to
        ``process_batch``.
        """
        model_input = model_inputs["model_input"]
        return process_batch(model_input["samples"], model_input["configs"])

    def postprocess(self, model_outputs):
        """Return the forward output (``process_batch`` result) unchanged."""
        return model_outputs