File size: 2,053 Bytes
32730d9 9e4b4a3 875bdf8 32730d9 875bdf8 32730d9 875bdf8 32730d9 875bdf8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from transformers import Pipeline
from src.deprecated.conversion import csv_to_pandas
from src.deprecated.pydantic_models import ECGConfig, ECGSample
from src.deprecated.ecg_processing import process_batch
class MyPipeline(Pipeline):
    """Custom pipeline that reads an ECG CSV file and runs ``process_batch`` on it.

    Stages:

    * ``preprocess`` loads the CSV, collapses the per-timestep rows into one
      row per sample and builds the ``ECGConfig`` / ``ECGSample`` objects.
    * ``_forward`` hands the samples and the config to ``process_batch``.
    * ``postprocess`` returns the result of ``_forward`` unchanged.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments between the pipeline stages.

        Only ``maybe_arg`` is recognised; it is routed to ``preprocess``.
        The forward and postprocess stages receive no extra parameters.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)`` tuple.
        """
        preprocess_kwargs = {}
        if "maybe_arg" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        return preprocess_kwargs, {}, {}

    @staticmethod
    def _extract_prefixed(frame, source_columns, prefix):
        """Return the first row's values for columns starting with *prefix*.

        The prefix is stripped from the keys of the returned dict.
        """
        selected = [col for col in source_columns if col.startswith(prefix)]
        first_row = frame[selected].iloc[0].to_dict()
        return {key.removeprefix(prefix): value for key, value in first_row.items()}

    def preprocess(self, inputs: str, maybe_arg=None) -> dict:
        """Load a CSV file and build the input for ``_forward``.

        Args:
            inputs: The CSV input, passed straight to ``csv_to_pandas``.
            maybe_arg: Accepted because ``_sanitize_parameters`` forwards it
                to this stage; without this parameter, calling the pipeline
                with ``maybe_arg=...`` raises ``TypeError``. Currently unused.

        Returns:
            ``{"model_input": {"samples": [...], "configs": ..., "batch": {...}}}``
            where ``samples`` is a list of ``ECGSample``, ``configs`` is an
            ``ECGConfig`` and ``batch`` is a plain dict.
        """
        # inputs are csv files
        df = csv_to_pandas(inputs)

        # Implode: every column that is not a per-timestep column identifies
        # a sample; the per-timestep columns are gathered into lists.
        cols_to_implode = ['timestamp_idx', 'ecg', 'label']
        implode_set = set(cols_to_implode)
        # Keep the CSV's column order (de-duplicated). Building the keys from
        # a set difference made their order, and therefore the sort order of
        # the samples and the row picked by ``iloc[0]``, vary between runs.
        group_cols = list(
            dict.fromkeys(col for col in df.columns if col not in implode_set)
        )
        # NOTE(review): groupby drops rows whose key columns contain NaN by
        # default, and an empty frame fails at ``iloc[0]`` below -- confirm
        # whether either case needs explicit handling.
        df_imploded = (
            df.groupby(group_cols)
            .agg({col: list for col in cols_to_implode})
            .reset_index()
        )

        # Get metadata: 'configs.*' and 'batch.*' values are read from the
        # first imploded row only.
        configs = ECGConfig(
            **self._extract_prefixed(df_imploded, df.columns, 'configs.')
        )
        batch = self._extract_prefixed(df_imploded, df.columns, 'batch.')

        # Get samples: one ECGSample per imploded row.
        samples = [
            ECGSample(**record)
            for record in df_imploded.to_dict(orient='records')
        ]

        model_input = {"samples": samples, "configs": configs, "batch": batch}
        return {"model_input": model_input}

    def _forward(self, model_inputs):
        """Run ``process_batch`` on the samples produced by ``preprocess``.

        Args:
            model_inputs: ``{"model_input": {...}}`` as returned by
                ``preprocess``.

        Returns:
            Whatever ``process_batch(samples, configs)`` returns.
        """
        # model_inputs == {"model_input": model_input}
        model_input = model_inputs["model_input"]
        # The "batch" metadata is carried in model_input but is not an
        # argument of process_batch, so it is not read here.
        return process_batch(model_input["samples"], model_input["configs"])

    def postprocess(self, model_outputs):
        """Return the output of ``_forward`` unchanged."""
        return model_outputs
|