File size: 18,036 Bytes
54f2589
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
## Fetch Model Registry and clemscores
import requests
import pandas as pd
from datetime import datetime
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np

from src.assets.text_content import REGISTRY_URL, REPO, BENCHMARK_FILE
from src.leaderboard_utils import get_github_data

# Cut-off date ('YYYY-MM-DD') at which the trend graph starts. get_final_trend_plot
# passes it to get_plot as start_date; get_plot drops models released before it
# and starts its two-monthly x-axis ticks here.
START_DATE = '2023-06-01'

def get_param_size(params: str) -> float:
    """Convert a parameter-count string to a number of billions.

    Args:
        params (str): The parameter size as a string with a ``B`` (billions)
            or ``T`` (trillions) suffix, e.g. ``'7B'``, ``'1000B'``, ``'1.5T'``.
            Surrounding whitespace is ignored.

    Returns:
        float: The parameter count in billions (``'1.5T'`` -> ``1500.0``).
            Returns ``0`` for empty/``None`` input. For strings that cannot be
            parsed (unknown suffix or non-numeric value), prints a warning and
            returns ``0`` instead of raising.
    """
    params = (params or "").strip()
    if not params:
        return 0

    # Factor converting each supported suffix into billions of parameters.
    multipliers = {"B": 1.0, "T": 1000.0}
    suffix = params[-1]
    if suffix not in multipliers:
        print("Not a valid parameter size")
        return 0

    try:
        return float(params[:-1]) * multipliers[suffix]
    except ValueError:
        # e.g. 'abcB' - keep the non-fatal behaviour used for unknown suffixes.
        print("Not a valid parameter size")
        return 0

def date_difference(date_str1: str, date_str2: str) -> int:
    """Return the signed number of whole days from ``date_str2`` to ``date_str1``.

    Both inputs are parsed with ``datetime.strptime`` using the 'YYYY-MM-DD'
    layout. The result is positive when ``date_str1`` falls after ``date_str2``,
    negative when it falls before, and 0 for the same day.

    Args:
        date_str1 (str): Date to subtract from, as 'YYYY-MM-DD'.
        date_str2 (str): Date being subtracted, as 'YYYY-MM-DD'.

    Returns:
        int: ``date_str1 - date_str2`` expressed in days.
    """
    first, second = (datetime.strptime(text, "%Y-%m-%d") for text in (date_str1, date_str2))
    delta = first - second
    return delta.days


def populate_list(df: pd.DataFrame, abs_diff: float) -> list:
    """Select the models that form a trend line through ``df``.

    Rows are walked in their current order (get_models_to_display sorts them
    by ``release_date`` before calling this). The first row always starts the
    list. A later row is selected when its clemscore is at least ``abs_diff``
    above the most recently selected clemscore. A negative ``abs_diff``
    therefore tolerates dips of up to ``-abs_diff``.

    If a selected row has the same release date as the previously selected
    one, it replaces that entry instead of being appended. The line thus keeps
    at most one point per date.

    Args:
        df (pd.DataFrame): Model data with ``model``, ``clemscore`` and
            ``release_date`` ('YYYY-MM-DD' string) columns.
        abs_diff (float): Minimum clemscore gain over the last selected model
            required to add a new point.

    Returns:
        list: Selected model names; an empty list if ``df`` has no rows.
    """
    if df.empty:
        # Nothing to trend (e.g. no models of this category after filtering).
        return []

    rows = zip(df['model'], df['clemscore'], df['release_date'])
    first_model, prev_clemscore, prev_date = next(rows)
    selected = [first_model]

    for model, clemscore, release_date in rows:
        # Written as ">=" (not "< ... continue") so NaN scores are never selected.
        if clemscore - prev_clemscore >= abs_diff:
            if date_difference(release_date, prev_date) == 0:
                selected[-1] = model
            else:
                selected.append(model)
            prev_clemscore = clemscore
            prev_date = release_date

    return selected


def get_models_to_display(result_df: pd.DataFrame, open_dip: float = 0, comm_dip: float = 0) -> tuple:
    """Retrieve the trend-line models for open-weight and commercial models.

    Each category is sorted by ``release_date`` and passed to ``populate_list``.

    Args:
        result_df (pd.DataFrame): Model data with ``model``, ``clemscore``,
            ``open_weight`` and ``release_date`` columns.
        open_dip (float, optional): Minimum clemscore gain over the previous
            trend point for open-weight models (negative values tolerate dips).
            Defaults to 0.
        comm_dip (float, optional): Same threshold for commercial models.
            Defaults to 0.

    Returns:
        tuple: ``(open_models, comm_models)``, two lists of model names.
    """
    def _trend(is_open: bool, threshold: float) -> list:
        # Trend-line models for one category (open_weight == is_open).
        subset = result_df[result_df['open_weight'] == is_open]
        # Stable sort: models sharing a release date keep their input order, so
        # populate_list's same-date replacement is deterministic.
        subset = subset.sort_values(by='release_date', ascending=True, kind='stable')
        return populate_list(subset, threshold)

    open_models = _trend(True, open_dip)
    comm_models = _trend(False, comm_dip)
    return open_models, comm_models


def get_trend_data(text_dfs: list, model_registry_data: list) -> pd.DataFrame:
    """Join leaderboard scores with model-registry metadata.

    For each model name (first occurrence across ``text_dfs`` wins), the first
    registry entry with the same ``model_name`` supplies ``open_weight``,
    ``release_date`` and ``parameters``. Models without a registry entry are
    skipped. An empty ``parameters`` value is replaced by an estimated
    ``"1000B"`` and flagged via ``est_flag``.

    Args:
        text_dfs (list): DataFrames with ``Model`` and ``Clemscore`` columns.
        model_registry_data (list): Registry dicts with ``model_name``,
            ``parameters``, ``open_weight`` and ``release_date`` keys.

    Returns:
        pd.DataFrame: Columns ``model``, ``clemscore``, ``open_weight``,
            ``release_date``, ``parameters`` (billions, via get_param_size)
            and ``est_flag``.
    """
    columns = ['model', 'clemscore', 'open_weight', 'release_date', 'parameters', 'est_flag']

    # Index the registry once instead of scanning it for every model.
    # setdefault keeps the FIRST entry for a duplicated name (first-match lookup).
    registry_by_name = {}
    for entry in model_registry_data:
        registry_by_name.setdefault(entry.get("model_name"), entry)

    visited = set()  # Model names already handled
    rows = []
    for df in text_dfs:
        for model_name, clemscore in zip(df['Model'], df['Clemscore']):
            if model_name in visited:
                continue
            visited.add(model_name)

            entry = registry_by_name.get(model_name)
            if entry is None:
                continue  # Not in the registry: no metadata to plot with

            if entry.get("parameters", "") == "":
                # Unknown size: plot with an estimated size and flag it.
                params, est_flag = "1000B", True
            else:
                params, est_flag = entry['parameters'], False

            rows.append({
                'model': model_name,
                'clemscore': clemscore,
                'open_weight': entry['open_weight'],
                'release_date': entry['release_date'],
                'parameters': get_param_size(params),
                'est_flag': est_flag,
            })

    # Build the frame once; appending row-by-row with .loc copies on each insert.
    return pd.DataFrame(rows, columns=columns)


def get_plot(df: pd.DataFrame, start_date: str = '2023-06-01', end_date: str = '2024-12-30',
             benchmark_ticks: dict = None, benchmark_update: dict = None, **plot_kwargs) -> go.Figure:
    """Generate a clemscore-vs-release-date scatter plot with trend lines.

    Args:
        df (pd.DataFrame): Model data with ``model``, ``clemscore``, ``open_weight``,
            ``release_date`` and ``parameters`` columns. The caller's frame is not modified.
        start_date (str, optional): Models released before this date are dropped, and the
            two-monthly x-axis ticks start here. Defaults to '2023-06-01'.
        end_date (str, optional): Last date for the two-monthly x-axis ticks.
            Defaults to '2024-12-30'.
        benchmark_ticks (dict, optional): Maps a benchmark version's release date
            (Timestamp or date string) to its version label. Each one becomes an extra
            x-axis tick. Defaults to None (no benchmark ticks).
        benchmark_update (dict, optional): Maps a last-updated date to the version label(s)
            updated on that date. A version matches an entry when ``version in value``
            (list membership, or substring containment for a plain string). Matched
            versions get a dashed marker line with hover text. Defaults to None.

    Keyword Args:
        open_dip (float, optional): Minimum clemscore gain over the previous trend point for an
            open-weight model to join the trend line; negative values tolerate dips. Defaults to 0.
        comm_dip (float, optional): Same threshold for commercial models. Defaults to 0.
        height (int, optional): Height of the plot in pixels.
        width (int, optional): Width of the plot in pixels; plotly's default is used when falsy.
        mobile_view (bool, optional): Show only trend-line models, compact tick labels and a
            fixed 0-110 y-range. Defaults to False.

    Returns:
        go.Figure: The generated plot.
    """
    benchmark_ticks = benchmark_ticks or {}
    benchmark_update = benchmark_update or {}

    open_dip = plot_kwargs.get('open_dip', 0)
    comm_dip = plot_kwargs.get('comm_dip', 0)
    height = plot_kwargs.get('height')
    width = plot_kwargs.get('width')
    mobile_view = bool(plot_kwargs.get('mobile_view', False))

    # Work on a copy so the caller's DataFrame does not gain helper columns.
    df = df.copy()
    # Desktop y-axis headroom is based on the best score before date filtering.
    max_clemscore = df['clemscore'].max()
    df['Release date'] = pd.to_datetime(df['release_date'], format='ISO8601')
    # Drop models released before start_date. The .copy() makes later column
    # writes act on a real frame rather than a slice.
    df = df[df['Release date'] >= pd.to_datetime(start_date)].copy()

    open_model_list, comm_model_list = get_models_to_display(df, open_dip, comm_dip)
    models_to_display = open_model_list + comm_model_list
    print(f"open_model_list: {open_model_list}, comm_model_list: {comm_model_list}")

    # Only models on a trend line get a text label.
    trend_models = set(models_to_display)
    df['label_model'] = df['model'].where(df['model'].isin(trend_models), "")

    # Mobile view is minimalistic: keep only the models on a trend line.
    if mobile_view:
        df = df[df['model'].isin(trend_models)].copy()

    df['Model Type'] = df['open_weight'].map({True: 'Open-Weight', False: 'Commercial'})

    # Marker size grows with sqrt(parameters); unknown sizes (0) fall back to sqrt(400).
    marker_size = df['parameters'].apply(lambda x: np.sqrt(x) if x > 0 else np.sqrt(400)).astype(float)

    open_color = 'red'
    comm_color = 'blue'

    fig = px.scatter(df,
                     x="Release date",
                     y="clemscore",
                     color="Model Type",  # Differentiates open-weight vs commercial by colour
                     hover_name="model",
                     size=marker_size,
                     size_max=40,  # Max size of the circles
                     template="plotly_white",
                     hover_data={
                         "Release date": True,
                         "clemscore": True,
                         "Model Type": True
                     },
                     custom_data=["model", "Release date", "clemscore"]  # Read by the hovertemplate below
                     )

    fig.update_traces(
        hovertemplate='Model Name: %{customdata[0]}<br>Release date: %{customdata[1]}<br>Clemscore: %{customdata[2]}<br>'
    )

    # Trend-line points, in date order for line plotting
    df_open = df[df['model'].isin(open_model_list)].sort_values(by='Release date')
    df_commercial = df[df['model'].isin(comm_model_list)].sort_values(by='Release date')

    # Benchmark version ticks, normalized to Timestamps and sorted chronologically.
    # The mobile label staggering below compares consecutive dates.
    tick_labels = {pd.to_datetime(date): version for date, version in benchmark_ticks.items()}
    benchmark_tickvals = sorted(tick_labels)

    # Regular ticks at every second month start ('2MS'), minus dates already used
    # by a benchmark tick.
    date_range = pd.date_range(start=pd.to_datetime(start_date), end=pd.to_datetime(end_date), freq='2MS')
    benchmark_dates = set(benchmark_tickvals)
    custom_tickvals = [date for date in date_range if date not in benchmark_dates]

    for date in benchmark_tickvals:
        version = tick_labels[date]
        # First last-updated date whose entry contains this version
        update_date = next((upd for upd, ver in benchmark_update.items() if version in ver), None)
        if update_date is None:
            continue
        update_label = pd.to_datetime(update_date).strftime('%d %b %Y')

        # Vertical grey dashed line at the benchmark release date
        fig.add_shape(
            go.layout.Shape(
                type='line',
                x0=date,
                x1=date,
                y0=0,
                y1=1,
                yref='paper',
                line=dict(color='#A9A9A9', dash='dash'),
            )
        )

        # Invisible hover targets along the line (y 0-99) showing the version info
        hover_text = f"Version: {version} released on {date.strftime('%d %b %Y')}, last updated on: {update_label}"
        fig.add_trace(
            go.Scatter(
                x=[date] * 100,
                y=list(range(0, 100)),
                mode='markers',
                line=dict(color='rgba(255,255,255,0)', width=0),  # Fully transparent
                hovertext=[hover_text] * 100,
                hoverinfo="text",
                hoveron='points',
                showlegend=False
            )
        )

    if mobile_view:
        # Drop regular ticks within one month of a benchmark tick to avoid label overlap
        one_month = pd.DateOffset(months=1)
        filtered_custom_tickvals = [
            date for date in custom_tickvals
            if not any((benchmark_date - one_month <= date <= benchmark_date + one_month)
                       for benchmark_date in benchmark_tickvals)
        ]
        # Push a benchmark label one line lower when it is within 75 days of the
        # previous one (e.g. v1.6 and v1.6.5 for the MM benchmark).
        benchmark_tick_texts = []
        prev_date = None
        for date in benchmark_tickvals:
            close_to_prev = prev_date is not None and (date - prev_date).days <= 75
            line_breaks = "<br><br><br>" if close_to_prev else "<br><br>"
            benchmark_tick_texts.append(f"{line_breaks}<b>{tick_labels[date]}</b>")
            prev_date = date
        fig.update_xaxes(
            tickvals=filtered_custom_tickvals + benchmark_tickvals,
            ticktext=[f"{date.strftime('%b')}<br>{date.strftime('%y')}" for date in filtered_custom_tickvals] +
                     benchmark_tick_texts,
            tickangle=0,
            tickfont=dict(size=10)
        )
        # Fixed 0-110 range leaves room for the legend and plotly's top-right toolbar
        fig.update_yaxes(range=[0, 110])
        display_mode = 'lines+markers'
    else:
        fig.update_xaxes(
            tickvals=custom_tickvals + benchmark_tickvals,
            ticktext=[f"{date.strftime('%b')} {date.strftime('%Y')}" for date in custom_tickvals] +
                     [f"<br><span style='font-size:12px;'><b>{tick_labels[date]}</b></span>" for date in benchmark_tickvals],  # <br> drops benchmark labels below month labels
            tickangle=0,
            tickfont=dict(size=10)
        )
        fig.update_yaxes(range=[0, max_clemscore + 10])
        display_mode = 'lines+markers+text'

    # Lines connecting the trend points for open models
    fig.add_trace(go.Scatter(x=df_open['Release date'], y=df_open['clemscore'],
                             mode=display_mode,
                             name='Open Models Trendline',
                             text=df_open['label_model'],
                             textposition='top center',
                             line=dict(color=open_color), showlegend=False))

    # Lines connecting the trend points for commercial models
    fig.add_trace(go.Scatter(x=df_commercial['Release date'], y=df_commercial['clemscore'],
                             mode=display_mode,
                             name='Commercial Models Trendline',
                             text=df_commercial['label_model'],
                             textposition='top center',
                             line=dict(color=comm_color), showlegend=False))

    # Place text labels above points on every trace
    fig.update_traces(textposition='top center')

    # Legend in the top-left corner; plot height from plot_kwargs
    fig.update_layout(height=height,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01
        )
    )

    if width:
        print("Custom Seting the Width :")
        fig.update_layout(width=width)

    return fig

def get_final_trend_plot(benchmark: str = "Text", mobile_view: bool = False, timeout: float = 30) -> go.Figure:
    """Fetch the registry, benchmark versions and scores, then build the trend plot.

    Args:
        benchmark (str, optional): "Text" selects the text leaderboard; any other
            value selects the multimodal leaderboard. Defaults to "Text".
        mobile_view (bool, optional): Use the compact 375x450 layout. Defaults to False.
        timeout (float, optional): Seconds to wait for each HTTP request made here.
            Defaults to 30.

    Returns:
        go.Figure: The generated trend plot for the selected benchmark.

    Raises:
        requests.HTTPError: If the registry or the benchmark-version file request
            returns an HTTP error status.
    """
    # Fetch model registry
    response = requests.get(REGISTRY_URL, timeout=timeout)
    response.raise_for_status()
    model_registry_data = response.json()

    # Benchmark version metadata, used for the custom x-axis ticks
    json_url = REPO + BENCHMARK_FILE
    response = requests.get(json_url, timeout=timeout)
    if response.status_code != 200:
        print(f"Failed to read JSON file: Status Code: {response.status_code}")
        response.raise_for_status()
    versions = response.json()['versions']

    if mobile_view:
        height, width = 450, 375
    else:
        height, width = 1000, None

    plot_kwargs = {'height': height, 'width': width, 'open_dip': 0, 'comm_dip': 0,
                   'mobile_view': mobile_view}

    is_text = benchmark == "Text"

    # Map release date -> version label (x-axis ticks) and
    # last-updated date -> list of version labels (hover info).
    benchmark_ticks = {}
    benchmark_update = {}
    for ver in versions:
        name = ver['version']
        is_multimodal_version = 'multimodal' in name
        if is_text and is_multimodal_version:
            continue  # Skip MM-specific benchmark dates
        if not is_text and not is_multimodal_version:
            continue  # MM benchmark only considers the *_multimodal versions (v1.6 onwards)
        label = name if is_text else name.replace('_multimodal', '')
        benchmark_ticks[pd.to_datetime(ver['release_date'])] = label
        # Several versions can share a last-updated date; keep them all, and use a
        # list so get_plot matches by membership rather than substring.
        benchmark_update.setdefault(pd.to_datetime(ver['last_updated']), []).append(label)

    dataset_key = 'text' if is_text else 'multimodal'
    leaderboard_dfs = get_github_data()[dataset_key]['dataframes']
    result_df = get_trend_data(leaderboard_dfs, model_registry_data)

    return get_plot(result_df, start_date=START_DATE, end_date=datetime.now().strftime('%Y-%m-%d'),
                    benchmark_ticks=benchmark_ticks, benchmark_update=benchmark_update, **plot_kwargs)