Skip to content

visualize

darts_model_plot(genotype, full_label=False, param_list=(), input_labels=(), out_dim=None, out_fnc=None, decimals_to_display=2)

Generates a graphviz plot for a DARTS model based on the genotype of the model.

Parameters:

Name Type Description Default
genotype Genotype

the genotype of the model

required
full_label bool

if True, edges are labeled with the full operation expression (including the coefficients taken from param_list)

False
param_list typing.Sequence

parameters used to build the full edge labels and the output-node biases (only read when full_label is True)

()
input_labels typing.Sequence

a list of labels to be included in the input nodes

()
out_dim int

the number of output nodes of the model

None
out_fnc str

the (activation) function to be used for the output nodes

None
decimals_to_display int

number of decimals to include in parameter values on plot

2
Source code in autora/theorist/darts/visualize.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def darts_model_plot(
    genotype: Genotype,
    full_label: bool = False,
    param_list: typing.Sequence = (),
    input_labels: typing.Sequence = (),
    out_dim: typing.Optional[int] = None,
    out_fnc: typing.Optional[str] = None,
    decimals_to_display: int = 2,
) -> Digraph:
    """
    Generates a graphviz plot for a DARTS model based on the genotype of the model.

    One node is drawn per entry of ``input_labels``, one intermediate node
    ``k1 .. kN`` per step of the genotype, and one or more output nodes. Every
    genotype operation other than ``"none"`` becomes an edge from its source node
    to its step node, and every intermediate node is connected to every output node.

    Arguments:
        genotype: the genotype of the model, a sequence of ``(op, j)`` pairs where
            ``op`` is the operation name and ``j`` the index of the edge's source
            node (input nodes first, then intermediate nodes)
        full_label: if True, edges are labeled with the full operation (including the
            coefficients taken from ``param_list``); otherwise edges are labeled with
            the operation name prefixed by its index
        param_list: parameters used to build the full edge labels and the output
            biases; only read when ``full_label`` is True
        input_labels: labels of the input nodes
        out_dim: the number of output nodes of the model; if None, a single output
            node named "out" is used
        out_fnc: the (activation) function to be used for the output nodes
        decimals_to_display: number of decimals to include in parameter values on plot

    Returns:
        the graphviz ``Digraph`` describing the model
    """

    # Builds e.g. "{:.2f}" for decimals_to_display == 2.
    format_string = "{:." + "{:.0f}".format(decimals_to_display) + "f}"

    graph = Digraph(
        edge_attr=dict(fontsize="20", fontname="times"),
        node_attr=dict(
            style="filled",
            shape="rect",
            align="center",
            fontsize="20",
            height="0.5",
            width="0.5",
            penwidth="2",
            fontname="times",
        ),
        engine="dot",
    )
    graph.body.extend(["rankdir=LR"])

    for input_node in input_labels:
        graph.node(input_node, fillcolor="#F1EDB9")  # fillcolor='darkseagreen2'

    # Determine the number of steps (intermediate nodes): each step contributes
    # exactly one genotype entry whose source is node 0.
    steps = 0
    for op, j in genotype:
        if j == 0:
            steps += 1

    for i in range(steps):
        graph.node("k" + str(i + 1), fillcolor="#BBCCF9")  # fillcolor='lightblue'

    params_counter = 0
    n = len(input_labels)
    start = 0
    for i in range(steps):
        end = start + n
        _logger.debug("start=%s, end=%s", start, end)
        # Step i owns genotype entries [start, end); this iteration is adapted
        # from get_genotype() in model_search.py.
        for k in range(start, end):
            _logger.debug("genotype=%s, k=%s", genotype, k)
            op, j = genotype[k]
            if j < len(input_labels):
                u = input_labels[j]
            else:
                u = "k" + str(j - len(input_labels) + 1)
            v = "k" + str(i + 1)
            # Remember the last genotype index seen; output-layer parameters
            # are looked up right after it in param_list.
            params_counter = k
            if op != "none":
                op_label = op
                if full_label:
                    # note: genotype order and param list order don't align
                    params = param_list[start + j]
                    op_label = get_operation_label(
                        op, params, decimals=decimals_to_display
                    )
                    graph.edge(u, v, label=op_label, fillcolor="gray")
                else:
                    graph.edge(
                        u,
                        v,
                        label="(" + str(j + start) + ") " + op_label,
                        fillcolor="gray",
                    )
        start = end
        # Each later step can also draw from the previous intermediate node.
        n += 1

    # determine output nodes

    out_nodes = list()
    if out_dim is None:
        out_nodes.append("out")
    else:
        biases = None
        if full_label:
            params = param_list[params_counter + 1]
            if len(params) > 1:
                # Index 1 holds the output biases; index 0 (the weights) is
                # used for the classifier edge labels below.
                biases = params[1]

        for idx in range(out_dim):
            out_str = ""
            # specify node ID
            if out_fnc is not None:
                out_str = out_str + out_fnc + "(r_" + str(idx)
            else:
                out_str = "(r_" + str(idx)

            if out_dim == 1:
                if out_fnc is not None:
                    out_str = "P(detected) = " + out_fnc + "(x"
                else:
                    out_str = "P_n = (x"

            # if available, add bias
            if biases is not None:
                out_str = out_str + " + " + format_string.format(biases[idx]) + ")"
            else:
                out_str = out_str + ")"

            # add node
            graph.node(out_str, fillcolor="#CBE7C7")  # fillcolor='palegoldenrod'
            out_nodes.append(out_str)

    # Connect every intermediate node to every output node.
    for i in range(steps):
        u = "k" + str(i + 1)
        if full_label:
            params_org = param_list[params_counter + 1 + i]  # count from k
            for out_idx, out_str in enumerate(out_nodes):
                params = list()
                params.append(params_org[0][out_idx])
                op_label = get_operation_label(
                    "classifier", params, decimals=decimals_to_display
                )
                graph.edge(u, out_str, label=op_label, fillcolor="gray")
        else:
            for out_idx, out_str in enumerate(out_nodes):
                graph.edge(u, out_str, label="linear", fillcolor="gray")

    return graph

plot(genotype, filename, file_format='pdf', view_file=None, full_label=False, param_list=(), input_labels=(), out_dim=None, out_fnc=None)

Generates a graphviz plot for a DARTS model based on the genotype of the model.

Parameters:

Name Type Description Default
genotype Genotype

the genotype of the model

required
filename str

the filename of the output file

required
file_format str

the format of the output file

'pdf'
view_file bool

if True, the rendered plot is opened in a viewer; if None, it is opened only when file_format is 'pdf'

None
full_label bool

if True, edges are labeled with the full operation expression (including the coefficients taken from param_list)

False
param_list typing.Tuple

parameters used to build the full edge labels and the output-node biases (only read when full_label is True)

()
input_labels typing.Tuple

a list of labels to be included in the input nodes

()
out_dim int

the number of output nodes of the model

None
out_fnc str

the (activation) function to be used for the output nodes

None
Source code in autora/theorist/darts/visualize.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def plot(
    genotype: Genotype,
    filename: str,
    file_format: str = "pdf",
    view_file: typing.Optional[bool] = None,
    full_label: bool = False,
    param_list: typing.Tuple = (),
    input_labels: typing.Tuple = (),
    out_dim: typing.Optional[int] = None,
    out_fnc: typing.Optional[str] = None,
    decimals_to_display: int = 2,
):
    """
    Generates a graphviz plot for a DARTS model based on the genotype of the model
    and renders it to a file.

    Arguments:
        genotype: the genotype of the model
        filename: the filename of the output file
        file_format: the format of the output file
        view_file: if True, the rendered plot is opened in a viewer; if None, it is
            opened only when ``file_format`` is "pdf"
        full_label: if True, edges are labeled with the full operation (including
            the coefficients taken from ``param_list``)
        param_list: parameters used for the full edge labels and output biases
        input_labels: labels of the input nodes
        out_dim: the number of output nodes of the model
        out_fnc: the (activation) function to be used for the output nodes
        decimals_to_display: number of decimals to include in parameter values on plot
    """

    g = darts_model_plot(
        genotype=genotype,
        full_label=full_label,
        param_list=param_list,
        input_labels=input_labels,
        out_dim=out_dim,
        out_fnc=out_fnc,
        decimals_to_display=decimals_to_display,
    )

    if view_file is None:
        # Open a viewer automatically only for PDF output.
        view_file = file_format == "pdf"

    g.render(filename, view=view_file, format=file_format)