Skip to content

Plot utils

cycle_default_score(cycle, x_vals, y_true)

Calculates score for each cycle using the estimator's default scorer.

Parameters:

Name Type Description Default
cycle Cycle

AER Cycle object that has been run

required
x_vals np.ndarray

Test dataset independent values

required
y_true np.ndarray

Test dataset dependent values

required

Returns:

Type Description

List of scores by cycle

Source code in autora/cycle/plot_utils.py
513
514
515
516
517
518
519
520
521
522
523
524
525
def cycle_default_score(cycle: Cycle, x_vals: np.ndarray, y_true: np.ndarray):
    """
    Scores every theory stored in the cycle with that estimator's own `score` method.

    Args:
        cycle: AER Cycle object that has been run
        x_vals: Test dataset independent values
        y_true: Test dataset dependent values

    Returns:
        List of scores, one per entry of `cycle.data.theories`, in the same order
    """
    scores_by_cycle = []
    for theory in cycle.data.theories:
        # Each theory is scored against the same test data.
        scores_by_cycle.append(theory.score(x_vals, y_true))
    return scores_by_cycle

cycle_specified_score(scorer, cycle, x_vals, y_true, **kwargs)

Calculates score for each cycle using specified sklearn scoring function.

Parameters:

Name Type Description Default
scorer Callable

sklearn scoring function

required
cycle Cycle

AER Cycle object that has been run

required
x_vals np.ndarray

Test dataset independent values

required
y_true np.ndarray

Test dataset dependent values

required
**kwargs

Keyword arguments to send to scoring function

{}
Source code in autora/cycle/plot_utils.py
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
def cycle_specified_score(
    scorer: Callable, cycle: Cycle, x_vals: np.ndarray, y_true: np.ndarray, **kwargs
):
    """
    Calculates score for each cycle using specified sklearn scoring function.

    The scorer's signature decides which predictions are requested from
    `_theory_predict`: a `y_pred` parameter requests `predict_proba=False`, otherwise a
    `y_score` parameter requests `predict_proba=True`. `y_pred` takes precedence if both
    are present.

    Args:
        scorer: sklearn scoring function
        cycle: AER Cycle object that has been run
        x_vals: Test dataset independent values
        y_true: Test dataset dependent values
        **kwargs: Keyword arguments to send to scoring function

    Returns:
        List of scores by cycle

    Raises:
        ValueError: If the scorer's signature has neither a `y_pred` nor a `y_score`
            parameter, so the kind of prediction it expects cannot be determined.
    """
    # Inspect the signature once and decide which kind of prediction is needed
    scorer_params = inspect.signature(scorer).parameters
    if "y_pred" in scorer_params:
        predict_proba = False
    elif "y_score" in scorer_params:
        predict_proba = True
    else:
        # Previously this case fell through and failed later with an UnboundLocalError
        scorer_name = getattr(scorer, "__name__", repr(scorer))
        raise ValueError(
            f"Scorer {scorer_name} must accept a 'y_pred' or 'y_score' parameter; "
            f"found parameters: {list(scorer_params)}"
        )

    # Get predictions
    l_y_pred = _theory_predict(cycle, x_vals, predict_proba=predict_proba)

    # Score each cycle
    return [scorer(y_true, y_pred, **kwargs) for y_pred in l_y_pred]

plot_cycle_score(cycle, X, y_true, scorer=None, x_label='Cycle', y_label=None, figsize=rcParams['figure.figsize'], ylim=None, xlim=None, scorer_kw={}, plot_kw={})

Plots scoring metrics of cycle's theories given test data.

Parameters:

Name Type Description Default
cycle Cycle

AER Cycle object that has been run

required
X np.ndarray

Test dataset independent values

required
y_true np.ndarray

Test dataset dependent values

required
scorer Optional[Callable]

sklearn scoring function (optional)

None
x_label str

Label for x-axis

'Cycle'
y_label Optional[str]

Label for y-axis

None
figsize Tuple[float, float]

Optional figure size tuple in inches

rcParams['figure.figsize']
ylim Optional[Tuple[float, float]]

Optional limits for the y-axis as a tuple (lower, upper)

None
xlim Optional[Tuple[float, float]]

Optional limits for the x-axis as a tuple (lower, upper)

None
scorer_kw dict

Dictionary of keywords for scoring function if scorer is supplied.

{}
plot_kw dict

Dictionary of keywords to pass to matplotlib 'plot' function.

{}

Returns:

Type Description
plt.Figure

matplotlib.figure.Figure

Source code in autora/cycle/plot_utils.py
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
def plot_cycle_score(
    cycle: Cycle,
    X: np.ndarray,
    y_true: np.ndarray,
    scorer: Optional[Callable] = None,
    x_label: str = "Cycle",
    y_label: Optional[str] = None,
    figsize: Tuple[float, float] = rcParams["figure.figsize"],
    ylim: Optional[Tuple[float, float]] = None,
    xlim: Optional[Tuple[float, float]] = None,
    scorer_kw: dict = {},
    plot_kw: dict = {},
) -> plt.Figure:
    """
    Plots scoring metrics of cycle's theories given test data.

    Scores are computed with `cycle_default_score` when no scorer is given, otherwise
    with `cycle_specified_score`, and plotted against the theory index.

    Args:
        cycle: AER Cycle object that has been run
        X: Test dataset independent values
        y_true: Test dataset dependent values
        scorer: sklearn scoring function (optional)
        x_label: Label for x-axis
        y_label: Label for y-axis. When omitted, the scorer's `__name__` is used if a
            scorer is supplied and has one; otherwise the label is "Score".
        figsize: Optional figure size tuple in inches
        ylim: Optional limits for the y-axis as a tuple (lower, upper)
        xlim: Optional limits for the x-axis as a tuple (lower, upper)
        scorer_kw: Dictionary of keywords for scoring function if scorer is supplied.
            Only read, never modified.
        plot_kw: Dictionary of keywords to pass to matplotlib 'plot' function.
            Only read, never modified.

    Returns:
        matplotlib.figure.Figure
    """

    # Use estimator's default scoring method if specific scorer is not supplied
    if scorer is None:
        l_scores = cycle_default_score(cycle, X, y_true)
    else:
        l_scores = cycle_specified_score(scorer, cycle, X, y_true, **scorer_kw)

    with plt.rc_context(controller_plotting_rc_context):
        # Plotting
        fig, ax = plt.subplots(figsize=figsize)
        ax.plot(np.arange(len(cycle.data.theories)), l_scores, **plot_kw)

        # Adjusting axis limits
        if ylim:
            ax.set_ylim(*ylim)
        if xlim:
            ax.set_xlim(*xlim)

        # Labeling
        ax.set_xlabel(x_label)
        if y_label is None:
            if scorer is not None:
                # functools.partial objects and callable instances have no __name__,
                # so fall back to the generic label instead of raising AttributeError.
                y_label = getattr(scorer, "__name__", "Score")
            else:
                y_label = "Score"
        ax.set_ylabel(y_label)
        # Cycle indexes are integers, so keep the x ticks on whole numbers
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    return fig

plot_results_panel_2d(cycle, iv_name=None, dv_name=None, steps=50, wrap=4, query=None, subplot_kw={}, scatter_previous_kw={}, scatter_current_kw={}, plot_theory_kw={})

Generates a multi-panel figure with 2D plots showing results of one AER cycle.

Observed data is plotted as a scatter plot with the current cycle colored differently than observed data from previous cycles. The current cycle's theory is plotted as a line over the range of the observed data.

Parameters:

Name Type Description Default
cycle Cycle

AER Cycle object that has been run

required
iv_name Optional[str]

Independent variable name. Name should match the name instantiated in the cycle object. Default will select the first.

None
dv_name Optional[str]

Single dependent variable name. Name should match the names instantiated in the cycle object. Default will select the first DV.

None
steps int

Number of steps to define the condition space to plot the theory.

50
wrap int

Number of panels to appear in a row. Example: 9 panels with wrap=3 results in a 3x3 grid.

4
query Optional[Union[List, slice]]

Query which cycles to plot with either a List of indexes or a slice. The slice must be constructed with the slice() function or np.s_[] index expression.

None
subplot_kw dict

Dictionary of keywords to pass to matplotlib 'subplot' function

{}
scatter_previous_kw dict

Dictionary of keywords to pass to matplotlib 'scatter' function that plots the data points from previous cycles.

{}
scatter_current_kw dict

Dictionary of keywords to pass to matplotlib 'scatter' function that plots the data points from the current cycle.

{}
plot_theory_kw dict

Dictionary of keywords to pass to matplotlib 'plot' function that plots the theory line.

{}
Source code in autora/cycle/plot_utils.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
def plot_results_panel_2d(
    cycle: Cycle,
    iv_name: Optional[str] = None,
    dv_name: Optional[str] = None,
    steps: int = 50,
    wrap: int = 4,
    query: Optional[Union[List, slice]] = None,
    subplot_kw: dict = {},
    scatter_previous_kw: dict = {},
    scatter_current_kw: dict = {},
    plot_theory_kw: dict = {},
) -> plt.figure:
    """
    Draws a grid of 2D panels, one panel per selected cycle of an AER Cycle run.

    In every panel the observations are drawn as a scatter plot, with the data gathered
    in that panel's cycle colored differently from data gathered in earlier cycles. The
    theory of that cycle is drawn on top as a line across the generated condition space.

    Args:
        cycle: AER Cycle object that has been run
        iv_name: Independent variable name. Name should match the name instantiated in the cycle
                    object. Default will select the first.
        dv_name: Single dependent variable name. Name should match the names instantiated in the
                    cycle object. Default will select the first DV.
        steps: Number of steps to define the condition space to plot the theory.
        wrap: Number of panels to appear in a row. Example: 9 panels with wrap=3 results in a
                3x3 grid.
        query: Query which cycles to plot with either a List of indexes or a slice. The slice must
                be constructed with the `slice()` function or `np.s_[]` index expression.
        subplot_kw: Dictionary of keywords to pass to matplotlib 'subplot' function
        scatter_previous_kw: Dictionary of keywords to pass to matplotlib 'scatter' function that
                    plots the data points from previous cycles.
        scatter_current_kw: Dictionary of keywords to pass to matplotlib 'scatter' function that
                    plots the data points from the current cycle.
        plot_theory_kw: Dictionary of keywords to pass to matplotlib 'plot' function that plots the
                    theory line.

    Returns: matplotlib figure

    """

    # ---Figure and plot params---
    # Each entry pairs the built-in defaults with the user supplied keywords
    kw_sources = {
        "subplot_kw": (
            {"gridspec_kw": {"bottom": 0.16}, "sharex": True, "sharey": True},
            subplot_kw,
        ),
        "scatter_previous_kw": (
            {"color": "black", "s": 2, "alpha": 0.6, "label": "Previous Data"},
            scatter_previous_kw,
        ),
        "scatter_current_kw": (
            {"color": "tab:orange", "s": 2, "alpha": 0.6, "label": "New Data"},
            scatter_current_kw,
        ),
        "plot_theory_kw": ({"label": "Theory"}, plot_theory_kw),
    }
    merged_kw = {}
    for kw_name, (defaults, supplied) in kw_sources.items():
        assert isinstance(defaults, dict)
        assert isinstance(supplied, dict)
        merged_kw[kw_name] = _check_replace_default_kw(defaults, supplied)

    # ---Extract IVs and DV metadata and indexes---
    ivs, dvs = _get_variable_index(cycle)
    # Entries are indexed below as [0] (data column/index), [1] (name), [2] (label suffix)
    iv = [entry for entry in ivs if entry[1] == iv_name][0] if iv_name else ivs[0]
    dv = [entry for entry in dvs if entry[1] == dv_name][0] if dv_name else dvs[0]
    iv_label = f"{iv[1]} {iv[2]}"
    dv_label = f"{dv[1]} {dv[2]}"

    # Observed data of the whole run as a dataframe
    df_observed = _observed_to_df(cycle)

    # Condition space over which the theories are evaluated
    condition_space = _generate_condition_space(cycle, steps=steps)

    # One set of predictions per theory
    l_predictions = _theory_predict(cycle, condition_space)

    # Cycle Indexing: all cycles by default, narrowed by a list or slice query
    cycle_idx = list(range(len(cycle.data.theories)))
    if query:
        if isinstance(query, list):
            cycle_idx = [cycle_idx[position] for position in query]
        elif isinstance(query, slice):
            cycle_idx = cycle_idx[query]

    # Subplot configurations: a single row while fewer panels than `wrap` are needed
    n_panels = len(cycle_idx)
    shape = (
        (1, n_panels) if n_panels < wrap else (int(np.ceil(n_panels / wrap)), wrap)
    )

    with plt.rc_context(controller_plotting_rc_context):
        fig, axs = plt.subplots(*shape, **merged_kw["subplot_kw"])
        # A single panel comes back as a bare axis; wrap it so `.flat` works below
        if shape == (1, 1):
            axs = np.array([axs])

        # Loop by panel
        for panel, ax in enumerate(axs.flat):
            # Grid cells beyond the requested cycles stay empty
            if panel >= n_panels:
                ax.axis("off")
                continue

            # Get index of cycle to plot
            i_cycle = cycle_idx[panel]

            # ---Plot observed data---
            # Independent variable values
            x_vals = df_observed.loc[:, iv[0]]
            # Mask out everything but earlier cycles / everything but this cycle
            dv_previous = np.ma.masked_where(
                df_observed["cycle"] >= i_cycle, df_observed[dv[0]]
            )
            dv_current = np.ma.masked_where(
                df_observed["cycle"] != i_cycle, df_observed[dv[0]]
            )
            # Plotting scatter
            ax.scatter(x_vals, dv_previous, **merged_kw["scatter_previous_kw"])
            ax.scatter(x_vals, dv_current, **merged_kw["scatter_current_kw"])

            # ---Plot Theory---
            conditions = condition_space[:, iv[0]]
            ax.plot(conditions, l_predictions[i_cycle], **merged_kw["plot_theory_kw"])

            # Label Panels
            ax.text(
                0.05,
                1,
                f"Cycle {i_cycle}",
                ha="left",
                va="top",
                transform=ax.transAxes,
            )

        # Super Labels
        fig.supxlabel(iv_label, y=0.07)
        fig.supylabel(dv_label)

        # Legend
        fig.legend(
            ["Previous Data", "New Data", "Theory"],
            ncols=3,
            bbox_to_anchor=(0.5, 0),
            loc="lower center",
        )

    return fig

plot_results_panel_3d(cycle, iv_names=None, dv_name=None, steps=50, wrap=4, view=None, subplot_kw={}, scatter_previous_kw={}, scatter_current_kw={}, surface_kw={})

Generates a multi-panel figure with 3D plots showing results of one AER cycle.

Observed data is plotted as a scatter plot with the current cycle colored differently than observed data from previous cycles. The current cycle's theory is plotted as a surface over the range of the observed data.

Parameters:

Name Type Description Default
cycle Cycle

AER Cycle object that has been run

required
iv_names Optional[List[str]]

List of up to 2 independent variable names. Names should match the names instantiated in the cycle object. Default will select up to the first two.

None
dv_name Optional[str]

Single DV name. Name should match the names instantiated in the cycle object. Default will select the first DV

None
steps int

Number of steps to define the condition space to plot the theory.

50
wrap int

Number of panels to appear in a row. Example: 9 panels with wrap=3 results in a 3x3 grid.

4
view Optional[Tuple[float, float]]

Tuple of elevation angle and azimuth to change the viewing angle of the plot.

None
subplot_kw dict

Dictionary of keywords to pass to matplotlib 'subplot' function

{}
scatter_previous_kw dict

Dictionary of keywords to pass to matplotlib 'scatter' function that plots the data points from previous cycles.

{}
scatter_current_kw dict

Dictionary of keywords to pass to matplotlib 'scatter' function that plots the data points from the current cycle.

{}
surface_kw dict

Dictionary of keywords to pass to matplotlib 'plot_surface' function that plots the theory plane.

{}
Source code in autora/cycle/plot_utils.py
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
def plot_results_panel_3d(
    cycle: Cycle,
    iv_names: Optional[List[str]] = None,
    dv_name: Optional[str] = None,
    steps: int = 50,
    wrap: int = 4,
    view: Optional[Tuple[float, float]] = None,
    subplot_kw: dict = {},
    scatter_previous_kw: dict = {},
    scatter_current_kw: dict = {},
    surface_kw: dict = {},
) -> plt.Figure:
    """
    Generates a multi-panel figure with 3D plots showing results of one AER cycle.

    Observed data is plotted as a scatter plot with the current cycle colored differently than
    observed data from previous cycles. The current cycle's theory is plotted as a surface over
    the generated mesh grid of the independent variables.

    Args:

        cycle: AER Cycle object that has been run
        iv_names: List of up to 2 independent variable names. Names should match the names
                    instantiated in the cycle object. Default will select up to the first two.
        dv_name: Single DV name. Name should match the names instantiated in the cycle object.
                    Default will select the first DV
        steps: Number of steps to define the condition space to plot the theory.
        wrap: Number of panels to appear in a row. Example: 9 panels with wrap=3 results in a
                3x3 grid.
        view: Tuple of elevation angle and azimuth to change the viewing angle of the plot.
        subplot_kw: Dictionary of keywords to pass to matplotlib 'subplot' function
        scatter_previous_kw: Dictionary of keywords to pass to matplotlib 'scatter' function that
                    plots the data points from previous cycles.
        scatter_current_kw: Dictionary of keywords to pass to matplotlib 'scatter' function that
                    plots the data points from the current cycle.
        surface_kw: Dictionary of keywords to pass to matplotlib 'plot_surface' function that plots
                    the theory plane.

    Returns: matplotlib figure

    """
    n_cycles = len(cycle.data.theories)

    # ---Figure and plot params---
    # Set defaults, check and add user supplied keywords
    # Default keywords
    subplot_kw_defaults = {
        "subplot_kw": {"projection": "3d"},
    }
    scatter_previous_defaults = {"color": "black", "s": 2, "label": "Previous Data"}
    scatter_current_defaults = {"color": "tab:orange", "s": 2, "label": "New Data"}
    surface_kw_defaults = {"alpha": 0.5, "label": "Theory"}
    # Combine default and user supplied keywords
    d_kw = {}
    for d1, d2, key in zip(
        [
            subplot_kw_defaults,
            scatter_previous_defaults,
            scatter_current_defaults,
            surface_kw_defaults,
        ],
        [subplot_kw, scatter_previous_kw, scatter_current_kw, surface_kw],
        ["subplot_kw", "scatter_previous_kw", "scatter_current_kw", "surface_kw"],
    ):
        assert isinstance(d1, dict)
        assert isinstance(d2, dict)
        d_kw[key] = _check_replace_default_kw(d1, d2)

    # ---Extract IVs and DV metadata and indexes---
    ivs, dvs = _get_variable_index(cycle)
    if iv_names:
        # iv_names is a list, so select by membership (an equality test against the
        # list never matches a single name). Capped at two, like the default branch.
        iv = [s for s in ivs if s[1] in iv_names][:2]
    else:
        iv = ivs[:2]
    if dv_name:
        dv = [s for s in dvs if s[1] == dv_name][0]
    else:
        dv = dvs[0]
    iv_labels = [f"{s[1]} {s[2]}" for s in iv]
    dv_label = f"{dv[1]} {dv[2]}"

    # Create a dataframe of observed data from cycle
    df_observed = _observed_to_df(cycle)

    # Generate IV Mesh Grid
    # NOTE(review): the mesh grid is built from the cycle alone and is not told which
    # IVs were selected above - confirm it matches `iv` when iv_names is supplied.
    x1, x2 = _generate_mesh_grid(cycle, steps=steps)

    # Get theory predictions over space
    l_predictions = _theory_predict(cycle, np.column_stack((x1.ravel(), x2.ravel())))

    # Subplot configurations
    if n_cycles < wrap:
        shape = (1, n_cycles)
    else:
        shape = (int(np.ceil(n_cycles / wrap)), wrap)
    with plt.rc_context(controller_plotting_rc_context):
        fig, axs = plt.subplots(*shape, **d_kw["subplot_kw"])
        # Place axis object in an array if plotting single panel, otherwise the bare
        # Axes returned for a 1x1 grid has no `.flat` / `.flatten()`
        if shape == (1, 1):
            axs = np.array([axs])

        # Loop by panel
        for i, ax in enumerate(axs.flat):
            if i + 1 <= n_cycles:

                # ---Plot observed data---
                # Independent variable values
                l_x = [df_observed.loc[:, s[0]] for s in iv]
                # Dependent values masked by current cycle vs previous data
                dv_previous = np.ma.masked_where(
                    df_observed["cycle"] >= i, df_observed[dv[0]]
                )
                dv_current = np.ma.masked_where(
                    df_observed["cycle"] != i, df_observed[dv[0]]
                )
                # Plotting scatter
                ax.scatter(*l_x, dv_previous, **d_kw["scatter_previous_kw"])
                ax.scatter(*l_x, dv_current, **d_kw["scatter_current_kw"])

                # ---Plot Theory---
                # Predictions come back flat; reshape them onto the mesh grid
                ax.plot_surface(
                    x1, x2, l_predictions[i].reshape(x1.shape), **d_kw["surface_kw"]
                )
                # ---Labels---
                # Title
                ax.set_title(f"Cycle {i}")

                # Axis
                ax.set_xlabel(iv_labels[0])
                ax.set_ylabel(iv_labels[1])
                ax.set_zlabel(dv_label)

                # Viewing angle
                if view:
                    ax.view_init(*view)

            else:
                # Grid cells beyond the last cycle stay empty
                ax.axis("off")

        # Legend
        # Built from the first panel; the surface handle is replaced by a Patch that
        # carries the surface's face color.
        handles, labels = axs.flatten()[0].get_legend_handles_labels()
        legend_elements = [
            handles[0],
            handles[1],
            Patch(facecolor=handles[2].get_facecolors()[0]),
        ]
        fig.legend(
            handles=legend_elements,
            labels=labels,
            ncols=3,
            bbox_to_anchor=(0.5, 0),
            loc="lower center",
        )

    return fig