File size: 4,131 Bytes
490def8
26cdd43
47913ac
26cdd43
 
 
 
9565da9
26cdd43
 
 
fab1822
 
490def8
 
26cdd43
490def8
47913ac
 
 
 
 
 
 
 
 
 
 
 
490def8
26cdd43
 
 
 
 
 
 
490def8
26cdd43
 
47913ac
26cdd43
490def8
26cdd43
490def8
26cdd43
490def8
26cdd43
 
 
490def8
47913ac
26cdd43
47913ac
 
 
 
 
 
 
26cdd43
 
47913ac
 
26cdd43
 
47913ac
26cdd43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47913ac
26cdd43
 
 
47913ac
26cdd43
 
490def8
26cdd43
 
 
490def8
26cdd43
 
490def8
26cdd43
 
 
 
 
 
 
490def8
 
 
26cdd43
 
 
 
 
 
 
490def8
26cdd43
 
 
 
 
 
 
 
 
 
 
 
 
490def8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Import necessary libraries
import gradio as gr
import sys
from huggingface_hub import ModelCard, HfApi
import requests
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from collections import defaultdict
from networkx.drawing.nx_pydot import graphviz_layout
from io import BytesIO
from PIL import Image

# Hugging Face Hub repository id of the model whose family tree is drawn when
# this file is executed (see the create_family_tree call at the end of the file).
MODEL_ID = "mlabonne/NeuralBeagle14-7B"

# Define a class to cache model cards
class CachedModelCard(ModelCard):
    """ModelCard variant that memoizes ``load`` results per model id.

    Failed loads are remembered as ``None`` so that a model whose card cannot
    be fetched is only requested once; callers must be prepared for ``None``.
    """

    # Maps model id -> loaded card, or None when loading that id failed.
    _cache = {}

    @classmethod
    def load(cls, model_id: str, **kwargs) -> "ModelCard | None":
        """Return the card for ``model_id``, loading it on first use.

        Args:
            model_id: Repository id passed through to ``ModelCard.load``.
            **kwargs: Extra arguments forwarded to ``ModelCard.load``; they
                are only used the first time a given id is requested.

        Returns:
            The cached card, or ``None`` if loading raised an exception.
        """
        if model_id not in cls._cache:
            try:
                cls._cache[model_id] = super().load(model_id, **kwargs)
            except Exception:
                # Deliberately best-effort, but do not trap KeyboardInterrupt
                # or SystemExit the way a bare ``except:`` would.
                cls._cache[model_id] = None
        return cls._cache[model_id]

# Function to get model names from a YAML file
def get_model_names_from_yaml(url):
    """Return the parent model names referenced in a merge configuration.

    Downloads ``url`` and scans the text for ``model:`` / ``base_model:``
    entries whose value looks like a repository id (``owner/name``).

    Args:
        url: Address of the YAML merge configuration to download.

    Returns:
        A list of repository ids in order of first appearance, without
        duplicates. Empty when the response status is not 200 or when no
        entry matches.
    """
    import re  # local import so this block does not depend on file-level edits

    # A timeout keeps one unresponsive request from stalling the whole
    # tree traversal indefinitely.
    response = requests.get(url, timeout=10)
    if response.status_code != 200:
        return []

    # Iterating ``response.content`` yields one integer per byte, so looking
    # for '/' in each item can never succeed; scan the decoded text instead.
    entry_pattern = re.compile(
        r'''^\s*(?:-\s*)?(?:base_model|model)\s*:\s*["']?([^\s"'#]+/[^\s"'#]+)''',
        re.MULTILINE,
    )
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(entry_pattern.findall(response.text)))

# Function to get the color of the model based on its license
def get_license_color(model):
    """Map a model's declared license to the node color used in the plot.

    Args:
        model: Model id handed to ``CachedModelCard.load``.

    Returns:
        ``'lightgreen'`` when the lower-cased license string contains one of
        the permissive markers, ``'lightcoral'`` when it does not, and
        ``'lightgray'`` when the license cannot be read for any reason
        (card unavailable, no ``license`` key, non-string value, ...).
    """
    permissive_markers = ('mit', 'bsd', 'apache-2.0', 'openrail')
    try:
        metadata = CachedModelCard.load(model).data.to_dict()
        license_name = metadata['license'].lower()
        is_permissive = False
        for marker in permissive_markers:
            # Substring match, so e.g. 'bsd' also covers 'bsd-3-clause'.
            if marker in license_name:
                is_permissive = True
                break
        return 'lightgreen' if is_permissive else 'lightcoral'
    except Exception:
        return 'lightgray'

# Function to find model names in the family tree
def get_model_names(model, genealogy, found_models=None, visited_models=None):
    """Recursively collect the ancestors of ``model`` and record their links.

    Parent models are taken from the first source that yields any: the card's
    ``base_model`` field, then the card ``tags`` that contain a ``/``, then
    the ``merge.yml`` and ``mergekit_config.yml`` files of the repository.

    Args:
        model: Id of the model to inspect.
        genealogy: Mapping of parent id -> list of child ids; updated in
            place with every parent/child link that is discovered.
        found_models: Set of models processed so far (created when omitted).
        visited_models: Set of models already entered, used to stop cycles
            and repeated work (created when omitted).

    Returns:
        ``found_models``: the models whose metadata could be processed.
    """
    if found_models is None:
        found_models = set()
    if visited_models is None:
        visited_models = set()

    if model in visited_models:
        return found_models
    visited_models.add(model)

    try:
        card = CachedModelCard.load(model)
        if card is None:
            # The card could not be loaded; nothing to learn about this model.
            return found_models
        card_dict = card.data.to_dict()

        # NOTE: the license is intentionally not read here. It was unused, and
        # indexing it made every model without license metadata get dropped.
        model_tags = _as_name_list(card_dict.get('base_model'))

        if not model_tags:
            model_tags = [
                tag for tag in _as_name_list(card_dict.get('tags')) if '/' in tag
            ]

        # "resolve" serves the raw file; "blob" is the HTML viewer page.
        for config_name in ('merge.yml', 'mergekit_config.yml'):
            if model_tags:
                break
            model_tags = _as_name_list(
                get_model_names_from_yaml(
                    f"https://huggingface.co/{model}/resolve/main/{config_name}"
                )
            )
    except Exception:
        # Best effort: a model whose metadata cannot be fetched or parsed is
        # skipped rather than aborting the whole traversal.
        return found_models

    found_models.add(model)

    for model_tag in model_tags:
        # setdefault works for both defaultdict and plain dict arguments.
        genealogy.setdefault(model_tag, []).append(model)
        get_model_names(model_tag, genealogy, found_models, visited_models)

    return found_models


def _as_name_list(value):
    """Normalize a metadata value (None, str or iterable) to a list of str."""
    if not value:
        return []
    if isinstance(value, str):
        return [value]
    return [item for item in value if isinstance(item, str)]

# Function to create the family tree
def create_family_tree(start_model):
    """Draw the family tree of ``start_model`` with matplotlib.

    Ancestors are discovered with ``get_model_names``; every parent -> child
    link becomes a directed edge, nodes are colored by license via
    ``get_license_color`` and laid out with Graphviz ``dot``.

    Args:
        start_model: Id of the model whose ancestry is plotted.

    Returns:
        None. The figure is displayed with ``plt.show()`` and then closed.
    """
    genealogy = defaultdict(list)
    get_model_names(start_model, genealogy)

    G = nx.DiGraph()
    # Always include the requested model so the plot is not empty when no
    # parent could be resolved.
    G.add_node(start_model)
    for parent, children in genealogy.items():
        for child in children:
            G.add_edge(parent, child)

    if nx.is_directed_acyclic_graph(G):
        max_depth = nx.dag_longest_path_length(G) + 1
        # Width = size of the largest topological generation. Computed inline
        # because this file defines no ``max_width_of_tree`` helper.
        max_width = max(len(level) for level in nx.topological_generations(G)) + 1
    else:
        # Inconsistent metadata can produce cycles (e.g. a model listing
        # itself as base model); levels are undefined then, so fall back to a
        # rough size estimate instead of raising.
        max_depth = G.number_of_nodes()
        max_width = 2

    height = max(8, 1.6 * max_depth)
    width = max(8, 6 * max_width)

    figure = plt.figure(figsize=(width, height))
    pos = graphviz_layout(G, prog="dot")

    node_colors = [get_license_color(node) for node in G.nodes()]

    # Put owner and model name on separate lines so labels fit in the nodes.
    labels = {node: node.replace("/", "\n") for node in G.nodes()}

    nx.draw(G, pos, labels=labels, with_labels=True, node_color=node_colors, font_size=12, node_size=8_000, edge_color='black')

    legend_elements = [
        Patch(facecolor='lightgreen', label='Permissive'),
        Patch(facecolor='lightcoral', label='Noncommercial'),
        Patch(facecolor='lightgray', label='Unknown'),
    ]
    plt.legend(handles=legend_elements, loc='upper left')

    plt.title(f"{start_model}'s Family Tree", fontsize=20)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(figure)

# Runs at import/execution time: builds and displays the tree for MODEL_ID.
create_family_tree(MODEL_ID)