File size: 2,456 Bytes
b9f0099 1ae96c8 88d1c0e 2088911 b9f0099 1ae96c8 88d1c0e b9f0099 1ae96c8 88d1c0e b9f0099 1ae96c8 88d1c0e b9f0099 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import functools
from lm_eval.api.metrics import mean
def process_results_decorator(func):
    """Wrap a task's ``process_results`` to strip and record timing data.

    Each entry of ``results`` is expected to be an indexable record whose
    element 0 is the model output and whose elements 1, 2 and 3 are, in
    order, the end-to-end time, the prefilling time and the decoding
    throughput (presumably produced by a timing-aware model backend --
    confirm against the caller). Any further elements are ignored.

    Only the outputs (element 0 of each entry) are forwarded to ``func``.
    The three timing values are averaged across ``results`` and written into
    the mapping ``func`` returns, under ``"end_to_end_time"``,
    ``"prefilling_time"`` and ``"decoding_throughput"``.

    If ``results`` is empty, no timing keys are added because the average
    would divide by zero. ``func`` still receives the empty output list and
    decides how to handle it.
    """

    @functools.wraps(func)
    def wrapper(self, doc, results, *args, **kwargs):
        outputs = [r[0] for r in results]

        # Average the per-entry timing fields; skip when there is nothing to
        # average instead of raising ZeroDivisionError.
        timings = {}
        count = len(results)
        if count:
            timings = {
                "end_to_end_time": sum(r[1] for r in results) / count,
                "prefilling_time": sum(r[2] for r in results) / count,
                "decoding_throughput": sum(r[3] for r in results) / count,
            }

        # func must return a mutable mapping; the timing metrics are merged
        # into it alongside the task's own metrics.
        result_dict = func(self, doc, outputs, *args, **kwargs)
        result_dict.update(timings)
        return result_dict

    return wrapper
def aggregation_decorator(func):
    """Wrap a task's ``aggregation`` to register the system timing metrics.

    The mapping returned by ``func`` is updated in place so that
    ``"end_to_end_time"``, ``"prefilling_time"`` and ``"decoding_throughput"``
    are each aggregated with ``mean``. That same mapping is then returned.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        aggregations = func(self, *args, **kwargs)
        for metric_name in ("end_to_end_time", "prefilling_time", "decoding_throughput"):
            aggregations[metric_name] = mean
        return aggregations

    return wrapper
def higher_is_better_decorator(func):
    """Wrap a task's ``higher_is_better`` with directions for timing metrics.

    The mapping returned by ``func`` is updated in place:
    ``"end_to_end_time"`` and ``"prefilling_time"`` are marked lower-is-better
    (``False``) and ``"decoding_throughput"`` is marked higher-is-better
    (``True``). That same mapping is then returned.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        directions = func(self, *args, **kwargs)
        directions.update(
            end_to_end_time=False,
            prefilling_time=False,
            decoding_throughput=True,
        )
        return directions

    return wrapper
def measure_system_metrics(cls):
    """Class decorator that adds system timing metrics to a task class.

    For each of ``process_results``, ``aggregation`` and ``higher_is_better``
    that ``cls`` exposes as a callable attribute (defined on the class itself
    or inherited), the attribute is replaced on ``cls`` by the version wrapped
    with the matching decorator from this module. ``cls`` is modified in place
    and returned.

    NOTE(review): there is no guard against double application. Decorating a
    subclass of an already-decorated class wraps the inherited methods a
    second time.
    """
    wrappers_by_method = (
        ("process_results", (process_results_decorator,)),
        ("aggregation", (aggregation_decorator,)),
        ("higher_is_better", (higher_is_better_decorator,)),
    )
    for name, wrappers in wrappers_by_method:
        method = getattr(cls, name, None)
        if not callable(method):
            continue
        # Apply from last to first so the first listed wrapper ends up outermost.
        for wrap in wrappers[::-1]:
            method = wrap(method)
        setattr(cls, name, method)
    return cls
|