|
import functools
import logging

from lm_eval.api.metrics import mean
|
|
|
|
|
def process_results_decorator(func):
    """Wrap a task's ``process_results`` so it also reports system metrics.

    The wrapped method receives ``results`` as a non-empty sequence of
    per-request records ``r``: ``r[0]`` is the model response, and ``r[1]``,
    ``r[2]``, ``r[3]`` are the end-to-end time, prefilling time and
    tokens-per-second recorded for that request. Only the responses are
    forwarded to ``func``; the three measurements are averaged over all
    records and added to the metrics dict that ``func`` returns.

    Raises:
        ValueError: if ``results`` is empty, since no average can be formed.
    """
    logger = logging.getLogger(__name__)

    @functools.wraps(func)
    def wrapper(self, doc, results, *args, **kwargs):
        if not results:
            # Averaging over zero records is undefined; fail with a clear
            # message instead of an opaque ZeroDivisionError.
            raise ValueError(
                "process_results received an empty results list; "
                "cannot average system metrics"
            )

        count = len(results)

        def _average(index):
            # Mean of field ``index`` across all per-request records.
            return sum(record[index] for record in results) / count

        end_to_end_time = _average(1)
        prefilling_time = _average(2)
        token_per_sec = _average(3)

        # Per-document diagnostics go through logging rather than stdout;
        # the same values are returned in the metrics dict below.
        logger.debug(
            "end_to_end_time: %s, prefilling_time: %s, token_per_sec: %s",
            end_to_end_time,
            prefilling_time,
            token_per_sec,
        )

        responses = [record[0] for record in results]
        result_dict = func(self, doc, responses, *args, **kwargs)

        result_dict["end_to_end_time"] = end_to_end_time
        result_dict["prefilling_time"] = prefilling_time
        result_dict["token_per_sec"] = token_per_sec
        return result_dict

    return wrapper
|
|
|
|
|
def aggregation_decorator(func):
    """Register ``mean`` as the aggregator for every system metric.

    The mapping returned by the wrapped ``aggregation`` method is extended in
    place with ``end_to_end_time``, ``prefilling_time`` and ``token_per_sec``,
    each aggregated with ``mean``, and then returned.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        aggregators = func(self, *args, **kwargs)
        for metric_name in ("end_to_end_time", "prefilling_time", "token_per_sec"):
            aggregators[metric_name] = mean
        return aggregators

    return wrapper
|
|
|
|
|
def higher_is_better_decorator(func):
    """Declare the preferred direction of each system metric.

    The two timing metrics are lower-is-better and throughput
    (``token_per_sec``) is higher-is-better. The mapping returned by the
    wrapped ``higher_is_better`` method is updated in place and returned.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        directions = func(self, *args, **kwargs)
        for metric_name, prefer_higher in (
            ("end_to_end_time", False),
            ("prefilling_time", False),
            ("token_per_sec", True),
        ):
            directions[metric_name] = prefer_higher
        return directions

    return wrapper
|
|
|
|
|
def measure_system_metrics(cls):
    """Class decorator that adds system metrics to an evaluation task class.

    Wraps ``process_results``, ``aggregation`` and ``higher_is_better`` on
    ``cls`` (including inherited versions, whenever they are callable) with
    the corresponding metric decorators, so that ``end_to_end_time``,
    ``prefilling_time`` and ``token_per_sec`` are computed, aggregated and
    ranked.

    A method that was already wrapped by this decorator is skipped. This
    covers a method inherited from a base class that was itself decorated,
    and applying the decorator twice. Re-wrapping ``process_results`` would
    strip ``r[0]`` twice and corrupt both the responses and the timings.

    Returns ``cls`` itself, modified in place.
    """
    method_decorators = {
        "process_results": [process_results_decorator],
        "aggregation": [aggregation_decorator],
        "higher_is_better": [higher_is_better_decorator],
    }
    # Attribute set on every method this decorator produces; used to detect
    # (and skip) methods that are already wrapped.
    marker = "_system_metrics_wrapped"

    for method_name, decorators in method_decorators.items():
        method = getattr(cls, method_name, None)
        if not callable(method) or getattr(method, marker, False):
            continue
        # Apply in reverse so the first listed decorator ends up outermost,
        # matching stacked ``@decorator`` syntax.
        for decorator in reversed(decorators):
            method = decorator(method)
        setattr(method, marker, True)
        setattr(cls, method_name, method)

    return cls
|
|