File size: 1,557 Bytes
523fad9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719c272
 
523fad9
 
 
 
 
 
719c272
523fad9
719c272
 
523fad9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import asyncio

import gradio as gr

import src.constants as constants
from src.hub import list_models, load_model_card


async def load_model_tree(result_paths_per_model, model_ids):
    """Build one Gradio dropdown per model-tree relation of the selected model.

    The first dropdown lists the model's base models. Each following dropdown lists the
    models derived from it for one entry of ``constants.DERIVED_MODEL_TYPES``, in that
    order. A model is offered as a choice only if it is a member of
    ``result_paths_per_model``.

    Args:
        result_paths_per_model: Container of model ids that have results. It is used only
            for ``in`` membership tests (presumably a dict keyed by model id — verify
            against the caller).
        model_ids: Selected model ids. Only the first one is used. An empty or ``None``
            selection yields empty dropdowns.

    Returns:
        A list of ``gr.Dropdown`` components: the base-model dropdown followed by one per
        derived model type. Each label shows its choice count, and dropdowns without
        choices are non-interactive.
    """
    model_tree_labels = [constants.BASE_MODEL_TYPE[0]] + [
        derived_model_type[0] for derived_model_type in constants.DERIVED_MODEL_TYPES
    ]
    # TODO: Multiple models?
    if not model_ids:
        # Nothing selected (e.g. a cleared dropdown): return empty, disabled dropdowns
        # instead of failing on model_ids[0].
        model_tree = [[] for _ in model_tree_labels]
    else:
        model_id = model_ids[0]
        # Fetch base models and every derived-model type concurrently. The result order
        # matches model_tree_labels.
        model_tree = await asyncio.gather(
            load_base_models(model_id),
            *[
                load_derived_models_by_type(model_id, derived_model_type[1])
                for derived_model_type in constants.DERIVED_MODEL_TYPES
            ],
        )
    # Keep only related models that have results available.
    model_tree_choices = [
        [related_id for related_id in related_ids if related_id in result_paths_per_model]
        for related_ids in model_tree
    ]
    return [
        gr.Dropdown(choices=choices, label=f"{label} ({len(choices)})", interactive=bool(choices))
        for choices, label in zip(model_tree_choices, model_tree_labels)
    ]


async def load_base_models(model_id) -> list[str]:
    """Return the base models declared in the model card of ``model_id``.

    Reads the card-data attribute named by ``constants.BASE_MODEL_TYPE[1]``. That
    attribute may hold a single id or a list of ids; the result is always a list.

    Args:
        model_id: Id of the model whose card is fetched via ``load_model_card``.

    Returns:
        The declared base-model ids. The list is empty when no card could be loaded or
        when the card declares no base model.
    """
    card = await load_model_card(model_id)
    if not card:
        return []
    # Default to None so a card without the field does not raise AttributeError.
    # An unset/empty value means "no base models"; returning [None] would break the
    # list[str] contract.
    base_models = getattr(card.data, constants.BASE_MODEL_TYPE[1], None)
    if not base_models:
        return []
    if not isinstance(base_models, list):
        base_models = [base_models]
    return base_models


async def load_derived_models_by_type(model_id, derived_model_type) -> list[str]:
    """List the ids of Hub models derived from ``model_id`` through one relation type.

    Args:
        model_id: Id of the parent model.
        derived_model_type: Relation tag inserted into the ``base_model:<type>:<id>``
            filter passed to ``list_models``.

    Returns:
        The ``"id"`` field of every model returned by ``list_models``. The list is empty
        when the query returns nothing.
    """
    query = f"base_model:{derived_model_type}:{model_id}"
    found = await list_models(filtering=query)
    if found:
        return [entry["id"] for entry in found]
    return []