Augmenting EEG with Generative Adversarial Networks
Created by Daniel Weinhardt, Chad Williams, & Sebastian Musslick
Here we use Generative Adversarial Networks (GANs) to create trial-level synthetic EEG samples. We can then use these samples as extra data to train whichever classifier we choose (e.g., a Support Vector Machine or a Neural Network).
GANs are machine learning frameworks that consist of two adversarial neural networks: a generator and a discriminator. The generator is trained to create novel samples that are indiscernible from real samples. In the current context, the generator produces realistic continuous EEG activity, conditioned on a set of experimental variables, that contains the underlying neural features representative of the outcomes being classified. For example, depression manifests as increased alpha oscillatory activity in the EEG signal, and thus an ideal generator would produce continuous EEG that includes these alpha signatures. The discriminator, in contrast, determines whether a given sample is real or synthetically produced by the generator. The core insight of GANs is that the generator can effectively learn from the discriminator: it progressively produces more realistic synthetic samples with the goal of "fooling" the discriminator into judging them as real. Once the generator produces samples that the discriminator can no longer distinguish from real data, it can be used to generate synthetic data, or in this context, synthetic EEG data.
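To make this adversarial training loop concrete, below is a minimal, illustrative PyTorch sketch. The two-layer networks, the latent_dim noise input, and the random real_batch are stand-ins chosen for this example only; the actual eeggan generator and discriminator are more sophisticated and condition-aware.
#Minimal GAN training sketch (illustrative only; not the eeggan architecture)
import torch
import torch.nn as nn
n_timepoints = 100   #assumed length of one EEG trial
latent_dim = 16      #assumed size of the generator's noise input
generator = nn.Sequential(nn.Linear(latent_dim, 64), nn.ReLU(), nn.Linear(64, n_timepoints))
discriminator = nn.Sequential(nn.Linear(n_timepoints, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid())
g_opt = torch.optim.Adam(generator.parameters(), lr=1e-4)
d_opt = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
bce = nn.BCELoss()
real_batch = torch.randn(32, n_timepoints)  #stand-in for a batch of real EEG trials
for step in range(5):
    #Discriminator step: learn to label real trials as real (1) and generated trials as fake (0)
    fake_batch = generator(torch.randn(32, latent_dim)).detach()
    d_loss = bce(discriminator(real_batch), torch.ones(32, 1)) + bce(discriminator(fake_batch), torch.zeros(32, 1))
    d_opt.zero_grad()
    d_loss.backward()
    d_opt.step()
    #Generator step: produce samples that the discriminator labels as real
    g_loss = bce(discriminator(generator(torch.randn(32, latent_dim))), torch.ones(32, 1))
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()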
The dataset provided is a subset of data from Williams et al., 2021 (Psychophysiology). In this study, participants completed a two-armed bandit gambling task in which they had to learn, through trial-and-error, which of two coloured squares was more often rewarding. Each trial presented two coloured squares for the participant to choose between and provided performance feedback of "WIN" or "LOSE", yielding the two conditions of interest: win and lose. For each pair of squares, one had a win rate of 60% while the other had a win rate of 10%. Participants saw each pair of colours twenty times consecutively. There were a total of five pairs of squares (with colours randomly determined), resulting in one hundred trials per participant. This paradigm elicits well-known frontal neural differences when contrasting the win and lose outcomes, namely in the reward positivity, delta oscillations, and theta oscillations (see Williams et al., 2021; Psychophysiology).
In this tutorial, we will classify the WIN and LOSE conditions using both Support Vector Machine and Neural Network classifiers. We will:
- Train a GAN on trial-level EEG data
- Generate synthetic EEG data
- Create an augmented EEG dataset
- Determine classification performance using both the empirical and augmented datasets
- Empirical Dataset: We train the classifier on the empirical data that was used to train the GAN
- Augmented Dataset: We train the classifier on the empirical data with the synthetic samples appended (see the sketch after this list)
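In other words, the augmented dataset is simply the empirical training set with the synthetic samples appended to it. Step 4.3 implements this; the sketch below (with placeholder array names and sizes) just illustrates the idea.
#Conceptual sketch of augmentation (placeholder arrays; see Step 4.3 for the real code)
import numpy as np
empirical_X, empirical_y = np.random.rand(30, 100), np.random.randint(0, 2, 30)    #real EEG samples and labels
synthetic_X, synthetic_y = np.random.rand(200, 100), np.random.randint(0, 2, 200)  #GAN-generated samples and labels
augmented_X = np.concatenate((empirical_X, synthetic_X))
augmented_y = np.concatenate((empirical_y, synthetic_y))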
Evaluation of EEG-GAN¶
Augmenting EEG with Generative Adversarial Networks Enhances Brain Decoding Across Classifiers and Sample Sizes
$Williams^{*1}$, $Weinhardt^{*2}$, $Wirzberger^{2}$, & $Musslick^{1}$ (submitted, 2023)
*Co-First Authors
$^{1}$ Brainstorm, Carney Institute for Brain Science, Brown University
$^{2}$ University of Stuttgart
Table of Contents¶
Step 0. Installing and Loading Modules
Step 0.1. Installing Modules
Step 0.2. Loading Modules
Step 1. EEG Data
Step 1.1. Load Data
Step 1.2. View Data
Step 2. GAN
Step 2.1. Exploring the Main GAN Package Functions
Step 2.1.1. GAN Training Help
Step 2.1.2. Visualize Help
Step 2.1.3. Generate Samples Help
Step 2.2. Training the GAN
Step 2.3. Visualizing GAN Losses
Step 2.4. Generating Synthetic Data
Step 3. Synthetic Data
Step 3.1. Load Data
Step 3.2. View Data
Step 3.2.1. View Trial-Level Data
Step 3.2.2. View ERP Data
Step 4. Classification Setup
Step 4.1. Preparing Validation Data
Step 4.2. Preparing Empirical Data
Step 4.3. Preparing Augmented Data
Step 5. Support Vector Machine
Step 5.1. Define Search Space
Step 5.2. Classify Empirical Data
Step 5.3. Classify Augmented Data
Step 6. Neural Network
Step 6.1. Define Search Space
Step 6.2. Classify Empirical Data
Step 6.3. Classify Augmented Data
Step 7. Final Report
Step 7.1. Present Classification Performance
Step 7.2. Plot Classification Performance
Note: you can also view an interactive table of contents in your sidebar
Step 0. Installing and Loading Modules¶
Step 0.1. Installing Modules¶
We will now download and install the EEG-GAN package
#%%capture
!pip install eeggan
[pip output trimmed] Successfully installed eeggan-0.0.22 along with its pinned dependencies (torch 1.12.1, torchvision 0.13.1, torchaudio 0.12.1, torchtext 0.13.1, numpy 1.23.5, pandas 1.3.5, scipy 1.8.1, scikit-learn 1.1.3, matplotlib 3.5.3, einops 0.4.1). pip reports dependency conflicts with the preinstalled Colab packages (torchdata, google-colab), which stem from the version pins above.
Step 0.2. Loading Modules¶
#Load EEG-GAN module
from eeggan import train_gan, visualize_gan, generate_samples, setup_tutorial
#Load other modules specific to this notebook
import numpy as np
import matplotlib.pyplot as plt
import shutil
import os
import random as rnd
from scipy import signal
from sklearn.preprocessing import scale
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import classification_report
import torch
#Create a print formatting class
class printFormat:
    bold = '\033[1m'
    italic = '\033[3m'
    end = '\033[0m'
#Setup
#This function downloads tutorial-required files (e.g., datasets) from GitHub. These files can also be found within the package itself, but Google Colab has difficulty accessing them there.
#This function is only necessary when running the tutorial, but it also creates the three folders (data, trained_models, generated_samples) that the package needs, so it may still be useful with your own data.
setup_tutorial()
Step 1. EEG Data¶
Step 1.1. Load Data¶
We will load the provided EEG training data and print some information about what this contains.
#Load the data
empiricalHeaders = np.genfromtxt('data/gansEEGTrainingData.csv', delimiter=',', names=True).dtype.names
empiricalEEG = np.genfromtxt('data/gansEEGTrainingData.csv', delimiter=',', skip_header=1)
#Print the head of the data
print(printFormat.bold + 'Display Header and first few rows/columns of data\n' + printFormat.end)
print(empiricalHeaders[:6])
print(empiricalEEG[0:3,:6])
#Print some information about the columns
print('\n------------------------------------------------------------------------------------------')
print(printFormat.bold + '\nNote the first three columns:' + printFormat.end +'\n ParticipantID - Indicates different participants\n Condition - Indicates the condition (WIN = 0, LOSE = 1) to be classified\n Trial - Indicates the trial number for that participant and condition')
print('\nThe remaining columns are titled Time1 to Time100 - indicating 100 datapoints per sample.\nThe samples span from -200 to 1000ms around the onset of a feedback stimulus.\nThese are downsampled from the original data, which contained 600 datapoints per sample.')
#Print some meta-data
print('\n------------------------------------------------------------------------------------------')
print('\n' + printFormat.bold + 'Other characteristics of our data include:' + printFormat.end)
print('-We have ' + str(len(set(empiricalEEG[:,0]))) + ' participants in our training set')
print('-Participants have an average of ' + str(round(np.mean([np.max(empiricalEEG[empiricalEEG[:,0]==pID,2]) for pID in set(empiricalEEG[:,0])]))) + ' (SD: ' + str(round(np.std([np.max(empiricalEEG[empiricalEEG[:,0]==pID,2]) for pID in set(empiricalEEG[:,0])]))) + ')' + ' trials per outcome (win, lose)')
Display Header and first few rows/columns of data ('ParticipantID', 'Condition', 'Trial', 'Time1', 'Time2', 'Time3') [[11. 0. 1. 2.287618 -4.448947 -0.980726] [11. 0. 2. 11.672745 0.669828 1.43829 ] [11. 0. 3. 10.510648 11.908067 8.671387]] ------------------------------------------------------------------------------------------ Note the first three columns: ParticipantID - Indicates different participants Condition - Indicates the condition (WIN = 0, LOSE = 1) to be classified Trial - Indicates the trial number for that participant and condition The remaining columns are titled Time1 to Time100 - indicating 100 datapoints per sample. The samples span from -200 to 1000ms around the onset of a feedback stimulus. These are downsampled from the original data, which contained 600 datapoints per sample. ------------------------------------------------------------------------------------------ Other characteristics of our data include: -We have 15 participants in our training set -Participants have an average of 46 (SD: 7) trials per outcome (win, lose)
Step 1.2. View Data¶
Let's view the grand-averaged ERPs of our 15 participants.
#Determine which rows are each condition
lossIndex = np.where(empiricalEEG[:,1]==1)
winIndex = np.where(empiricalEEG[:,1]==0)
#Grand average the waveforms for each condition
lossWaveform = np.mean(empiricalEEG[lossIndex,3:],axis=1)[0]
winWaveform = np.mean(empiricalEEG[winIndex,3:],axis=1)[0]
#Determine x axis of time
time = np.linspace(-200,1000,100)
#Setup figure
f, (ax1) = plt.subplots(1, 1, figsize=(6, 4))
#Plot each waveform
ax1.plot(time, lossWaveform, label = 'Loss')
ax1.plot(time, winWaveform, label = 'Win')
#Format plot
ax1.set_ylabel('Voltage ($\mu$V)')
ax1.set_xlabel('Time (ms)')
ax1.set_title('Empirical', loc='left')
ax1.spines[['right', 'top']].set_visible(False)
ax1.legend(frameon=False)
[Figure: grand-averaged Win and Loss ERPs from the empirical training data]
Step 2. GAN¶
Step 2.1. Exploring the Main GAN Package Functions¶
Functions
We will be using three functions from the GANs package:
train_gan()
- This trains a GAN
visualize_gan()
- This visualizes components of a trained GAN, such as the training losses
generate_samples()
- This generates synthetic samples using the trained GAN
Arguments
Each function can take a single argument, argv, which should be a dictionary:
argv = dict(
    path_dataset = "data/my_data.csv",
    n_epochs = 100
)

train_gan(argv)
Help
You can use the help argument to see a list of possible arguments with a brief description:
train_gan(dict(help = True))
visualize_gan(dict(help = True))
generate_samples(dict(help = True))
Step 2.1.1. GAN Training Help¶
train_gan(dict(help = True))
----------------------------------------- Command line arguments: ----------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------------------------- INPUT HELP - These are the inputs that can be given from the command line ---------------------------------------------------------------------------------------------------------------------------------------------------- Input | Type | Description | Default value ---------------------------------------------------------------------------------------------------------------------------------------------------- ddp | <class 'bool'> | Activate distributed training | False load_checkpoint | <class 'bool'> | Load a pre-trained GAN | False train_gan | <class 'bool'> | Train a GAN | True filter_generator | <class 'bool'> | Use low-pass filter on the generator output | False windows_slices | <class 'bool'> | Use sliding windows instead of whole sequences | False n_epochs | <class 'int'> | Number of epochs | 100 batch_size | <class 'int'> | Batch size | 128 patch_size | <class 'int'> | Patch size | 20 sequence_length | <class 'int'> | Used length of the datasets sequences; If None, then the whole sequence is used | -1 seq_len_generated | <class 'int'> | Length of the generated sequence | -1 sample_interval | <class 'int'> | Interval of epochs between saving samples | 10 learning_rate | <class 'float'> | Learning rate of the GAN | 0.0001 path_dataset | <class 'str'> | Path to the dataset | data/gansEEGTrainingData.csv path_checkpoint | <class 'str'> | Path to the checkpoint | trained_models/checkpoint.pt ddp_backend | <class 'str'> | Backend for the DDP-Training; "nccl" for GPU; "gloo" for CPU; | nccl conditions | <class 'str'> | ** Conditions to be used | Condition kw_timestep_dataset | <class 'str'> | Keyword for the time step of the dataset | Time ---------------------------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------------------------- QUICK HELP - These are the special features: ---------------------------------------------------------------------------------------------------------------------------------------------------- General information: Boolean arguments are given as a single keyword: Set boolean keyword "test_keyword" to True -> python file.py test_keyword Command line arguments are given as a keyword followed by an equal sign and the value: Set command line argument "test_keyword" to "test_value" -> python file.py test_keyword=test_value Whitespaces are not allowed between a keyword and its value. Some keywords can be given list-like: test_keyword=test_value1,test_value2 These keywords are marked with ** in the table. ---------------------------------------------------------------------------------------------------------------------------------------------------- 1. The training works with two levels of checkpoint files: 1.1 During the training: Checkpoints are saved every "sample_interval" batches as either "checkpoint_01.pt" or "checkpoint_02.pt". These checkpoints are considered as low-level checkpoints since they are only necessary in the case of training interruption. Hereby, they can be used to continue the training from the most recent sample. 
To continue training, the most recent checkpoint file must be renamed to "checkpoint.pt". Further, these low-level checkpoints carry the generated samples for inference purposes. 1.2 After finishing the training: A high-level checkpoint is saved as "checkpoint.pt", which is used to continue training in another session. This high-level checkpoint does not carry the generated samples. To continue training from this checkpoint file no further adjustments are necessary. Simply give the keyword "load_checkpoint" when calling the training process. The low-level checkpoints are deleted after creating the high-level checkpoint. 1.3 For inference purposes: Another dictionary is saved as "gan_{n_epochs}ep_{timestamp}.pt". This file contains everything the checkpoint file contains, plus the generated samples. 2. Use "ddp=True" to activate distributed training. Only if multiple GPUs are available for one node. All available GPUs are used for training. Each GPUs trains on the whole dataset. Hence, the number of training epochs is multiplied by the number of GPUs 3. If you want to load a pre-trained GAN, you can use the following command: python train_gan.py load_checkpoint; The default file is "trained_models/checkpoint.pt" If you want to use an other file, you can use the following command: python train_gan.py load_checkpoint path_checkpoint="path/to/file.pt" 4. If you want to use a different dataset, you can use the following command: python train_gan.py path_dataset="path/to/file.csv" The default dataset is "data/ganAverageERP.csv" 5. The keyword "sequence_length" has two different meanings based on the keyword "windows_slices": 5.1 "windows_slices" is set to "False": The keyword "sequence_length" defines the length of the taken sequence from the dataset. Hereby, only the first {sequence_length} data points are taken from each sample. The default value is -1, which means that the whole sequence is taken. 5.2 "windows_slices" is set to "True": The keyword "sequence_length" defines the length of a single window taken from the dataset. Hereby, a sample from the dataset is sliced into windows of length "sequence_length". Each window is then used as a single sample. The samples are taken by moving the window with a specific stride (=5) over the samples. 5. Have in mind to change the keyword patch_size if you use another value for the keyword sequence_length. The condition sequence_length % patch_size == 0 must be fulfilled. Otherwise the sequence will be padded with zeros until the condition is fulfilled. 6. The keyword "seq_len_generated" describes the length of the generated sequences. 6.1 The condition "seq_len_generated" <= "sequence_length" must be fulfilled. 6.2 The generator works in the following manner: The generator gets a sequence of length ("sequence_length"-"seq_len_generated") as a condition (input). The generator generates a sequence of length "seq_len_generated" as output which is used as the subsequent part of the input sequence. 6.3 If ("seq_len_generated" == "sequence_length"): The generator does not get any input sequence but generates an arbitrary sequence of length "sequence_length". Arbitrary means hereby that the generator does not get any conditions on previous data points. ---------------------------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------------------------
Step 2.1.2. Visualize Help¶
visualize_gan(dict(help = True))
----------------------------------------- Command line arguments: ----------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------------------------- INPUT HELP - These are the inputs that can be given from the command line ---------------------------------------------------------------------------------------------------------------------------------------------------- Input | Type | Description | Default value ---------------------------------------------------------------------------------------------------------------------------------------------------- file | <class 'str'> | File to be used | trained_models/checkpoint.pt training_file | <class 'str'> | Path to the original data | data/ganAverageERP.csv kw_timestep_dataset | <class 'str'> | Keyword for the time step of the dataset | Time conditions | <class 'str'> | ** Conditions to be used | Condition checkpoint | <class 'bool'> | Use samples from training checkpoint file | False experiment | <class 'bool'> | Use samples from experimental data | False csv_file | <class 'bool'> | Use samples from csv-file | False plot_losses | <class 'bool'> | Plot training losses | False averaged | <class 'bool'> | Average over all samples to get one averaged curve | False pca | <class 'bool'> | Use PCA to reduce the dimensionality of the data | False tsne | <class 'bool'> | Use t-SNE to reduce the dimensionality of the data | False spectogram | <class 'bool'> | Use spectogram to visualize the frequency distribution of the data | False fft_hist | <class 'bool'> | Use a FFT-histogram to visualize the frequency distribution of the data | False save | <class 'bool'> | Save the generated plots in the directory "plots" instead of showing them | False bandpass | <class 'bool'> | Use bandpass filter from models.TtsGeneratorFiltered.filter() on samples | False n_conditions | <class 'int'> | Number of conditions as first columns in data | 1 n_samples | <class 'int'> | Total number of samples to be plotted | 10 batch_size | <class 'int'> | Number of samples in one plot | 10 starting_row | <class 'int'> | Starting row of the dataset | 0 tsne_perplexity | <class 'int'> | Perplexity of t-SNE | 40 tsne_iterations | <class 'int'> | Number of iterations of t-SNE | 1000 ---------------------------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------------------------- QUICK HELP - These are the special features: ---------------------------------------------------------------------------------------------------------------------------------------------------- General information: Boolean arguments are given as a single keyword: Set boolean keyword "test_keyword" to True -> python file.py test_keyword Command line arguments are given as a keyword followed by an equal sign and the value: Set command line argument "test_keyword" to "test_value" -> python file.py test_keyword=test_value Whitespaces are not allowed between a keyword and its value. Some keywords can be given list-like: test_keyword=test_value1,test_value2 These keywords are marked with ** in the table. ---------------------------------------------------------------------------------------------------------------------------------------------------- 1. 
The keyword "file" carries some special features: 1.1 It is possible to give only a file instead of a whole file path. In this case, the default path is specified regarding the following keywords: "checkpoint" -> path = "trained_models" "experiment" -> path = "data" "csv_file" -> path = "generated_samples" 1.2 Specification of the keyword "file": The default file works only in combination with the keyword "checkpoint". In any other case, the default file must be specified with a compatible file name. 2. If the keyword "starting_row" is given, the dataset will start from the given row. This utility is useful to skip early training stage samples. The value can also be negative to specify the last n entries e.g. "starting_row=-100": The last 100 samples of the dataset are used. 4. The keyword "plot_losses" works only with the keyword "checkpoint". 5. When using the keywords "pca" or "tsne" the "training_file" can be defined. The legend corresponds to the plotted samples according to: red -> samples from "training_file" blue -> samples from "file" Therefore, the blue samples can also correspond to samples from an experiment file. ----------------------------------------------------------------------------------------------------------------------------------------------------
Step 2.1.3. Generate Samples Help¶
generate_samples(dict(help = True))
----------------------------------------- Command line arguments: ----------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------------------------- INPUT HELP - These are the inputs that can be given from the command line ---------------------------------------------------------------------------------------------------------------------------------------------------- Input | Type | Description | Default value ---------------------------------------------------------------------------------------------------------------------------------------------------- file | <class 'str'> | File which contains the trained model and its configuration | trained_models/checkpoint.pt path_samples | <class 'str'> | File where to store the generated samples; If None, then checkpoint name is used | None kw_timestep_dataset | <class 'str'> | Keyword for the time step of the dataset; to determine the sequence length | Time sequence_length_total | <class 'int'> | total sequence length of generated sample; if -1, then sequence length from training dataset | -1 num_samples_total | <class 'int'> | total number of generated samples | 1000 num_samples_parallel | <class 'int'> | number of samples generated in parallel | 50 conditions | <class 'int'> | ** Specific condition; -1 -> random condition (only for binary condition) | -1 average | <class 'int'> | Average over n latent variables to get an averaged one | 1 all_cond_per_z | <class 'bool'> | PRELIMINARY; ONLY FOR SINGLE BINARY CONDITION; Generate all conditions per latent variable | False ---------------------------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------------------------- QUICK HELP - These are the special features: ---------------------------------------------------------------------------------------------------------------------------------------------------- General information: Boolean arguments are given as a single keyword: Set boolean keyword "test_keyword" to True -> python file.py test_keyword Command line arguments are given as a keyword followed by an equal sign and the value: Set command line argument "test_keyword" to "test_value" -> python file.py test_keyword=test_value Whitespaces are not allowed between a keyword and its value. Some keywords can be given list-like: test_keyword=test_value1,test_value2 These keywords are marked with ** in the table. ---------------------------------------------------------------------------------------------------------------------------------------------------- 1. The keyword "file" carries some special features: 1.1 It is possible to give only a file instead of a whole file path In this case, the default path is "trained_models" 1.2 The specified file must be a checkpoint file which contains the generator state dict and its corresponding configuration dict 2. The keyword "sequence_length_total" defines the length of the generated sequences The default value is -1, which means that the max sequence length is chosen The max sequence length is determined by the used training data set given by the configuration dict 3. 
The keyword "condition" defines the condition for the generator: 3.1 Hereby, the value can be either a scalar or a comma-seperated list of scalars e.g., "condition=1,3.234,0" Current implementation: The single elements must be numeric The length of the condition must be equal to the "n_condition" parameter in the configuration dict 3.2 The value -1 means that the condition is chosen randomly This works currently only for binary conditions. 4. The keyword "num_samples_parallel" defines the number of generated samples in one batch This parameter should be set according to the processing power of the used machine Especially, the generation of large number of sequences can be boosted by increasing this parameter
Step 2.2. Training the GAN¶
To train the GAN in this tutorial, we will be using the following arguments:
- path_dataset=data/gansEEGTrainingData.csv : Determines the training dataset
- n_epochs=5 : Determines the number of epochs for which to train the GAN
  - Here we only use 5 epochs to demonstrate the process, which will result in a very under-trained GAN. In the manuscript, we trained for 8000 epochs.
Note: If the ddp argument is provided, the GAN will be trained using GPUs rather than CPUs
#Train the GAN on CPUs
argv = dict(
    path_dataset = "data/gansEEGTrainingData.csv",
    n_epochs = 5
)

train_gan(argv)

#Train the GAN on GPUs
#Note, on Google Colab you can start a GPU runtime by going to Runtime > Change runtime type > Hardware accelerator > GPU
'''
argv = dict(
    ddp = True,
    path_dataset = "data/gansEEGTrainingData.csv",
    n_epochs = 5
)

train_gan(argv)
'''
----------------------------------------- Command line arguments: ----------------------------------------- Dataset: data/gansEEGTrainingData.csv Number of epochs: 5 ----------------------------------------- System output: ----------------------------------------- Generator and discriminator initialized. ----------------------------------------- Training GAN... ----------------------------------------- [Epoch 1/5] [D loss: 8.342270] [G loss: 0.171996] [Epoch 2/5] [D loss: 7.199574] [G loss: 0.587039] [Epoch 3/5] [D loss: 5.645170] [G loss: 1.041533] [Epoch 4/5] [D loss: 4.101204] [G loss: 1.683151] [Epoch 5/5] [D loss: 2.708827] [G loss: 2.192675] Managing checkpoints... GAN training finished. Generated samples saved to file. Model states saved to file.
Step 2.3. Visualizing GAN Losses¶
The GAN trains for the number of epochs specified above; however, this does not guarantee that training was successful. It is therefore important to visualize the training losses and confirm that training converged. If it did, we can move forward with using the GAN; if it did not, we need to continue training it. This is not a problem, because the package lets you continue training a previously trained GAN (rather than starting over) via the load_checkpoint and path_checkpoint arguments of the train_gan function.
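For example, resuming training from a previously saved checkpoint might look like the following sketch, which uses only arguments listed in the help output above and assumes that trained_models/checkpoint.pt exists from an earlier run:
#Continue training a previously trained GAN from its checkpoint (sketch)
'''
argv = dict(
    load_checkpoint = True,
    path_checkpoint = "trained_models/checkpoint.pt",
    path_dataset = "data/gansEEGTrainingData.csv",
    n_epochs = 100
)

train_gan(argv)
'''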
We will now visualize the generator and discriminator losses using the following arguments:
- plot_losses : Determines that we will be viewing the losses
- checkpoint : Specifies that we are visualizing a GAN
- file=gansEEGModel.pt : Determines which GAN to visualize
- training_file=data/gansEEGTrainingData.csv : Points towards the data used to train the GAN
We will know that training was successful if both the generator and discriminator losses hover around 0 at the end of training.
#We trained our GAN for only 5 epochs, which results in a severely under-trained model, so we will instead use a pre-trained GAN that was trained for 8000 epochs:
argv = dict(
    plot_losses = True,
    checkpoint = True,
    file = "gansEEGModel.pt",
    training_file = "data/gansEEGTrainingData.csv"
)

visualize_gan(argv)

#The GAN training from the last step produced a file named checkpoint.pt. If you want to continue with that file instead, use the following code:
'''
argv = dict(
    plot_losses = True,
    checkpoint = True,
    file = "checkpoint.pt",
    training_file = "data/gansEEGTrainingData.csv"
)

visualize_gan(argv)
'''
----------------------------------------- Command line arguments: ----------------------------------------- Plotting training losses Using samples from checkpoint file File: gansEEGModel.pt Training dataset: data/gansEEGTrainingData.csv ----------------------------------------- System output: -----------------------------------------
Step 2.4. Generating Synthetic Data¶
We will be using the following arguments:
- file=gansEEGModel.pt : Determines which model to use
  - The default trained GAN name is checkpoint.pt, but we will instead use a pre-trained GAN named gansEEGModel.pt
- path_samples=gansEEGSyntheticData.csv : Where to save the generated samples and what to name the file
- num_samples_total=10000 : Number of samples to generate (half per condition)
#We trained our GAN for only 5 epochs, which results in a severely under-trained model, so we will instead use a pre-trained GAN that was trained for 8000 epochs:
argv = dict(
    file = "gansEEGModel.pt",
    path_samples = "gansEEGSyntheticData.csv",
    num_samples_total = 10000
)

generate_samples(argv)

#The GAN training from the last step produced a file named checkpoint.pt. If you want to generate samples with that file instead, use the following code:
'''
argv = dict(
    file = "checkpoint.pt",
    path_samples = "gansEEGSyntheticData.csv",
    num_samples_total = 10000
)

generate_samples(argv)
'''
----------------------------------------- Command line arguments: ----------------------------------------- File: gansEEGModel.pt Saving generated samples to file: gansEEGSyntheticData.csv Total number of generated samples: 10000 ----------------------------------------- System output: ----------------------------------------- Initializing generator... Generating samples... Generating sequence 1 of 200... Generating sequence 2 of 200... [...] Generating sequence 200 of 200... Saving samples...
Generated samples were saved to generated_samples/gansEEGSyntheticData.csv
Step 3. Synthetic Data¶
Step 3.1. Load Data¶
We will now load the synthetic data we just produced, and confirm the number of samples per condition
#Load Data
syntheticEEG = np.genfromtxt('generated_samples/gansEEGSyntheticData.csv', delimiter=',', skip_header=1)
#Print head of the data
print(printFormat.bold + 'Display first few rows/columns of data' + printFormat.end)
print(['Condition','Time1','Time2','Time3','Time4','Time5'])
print(syntheticEEG[0:3,0:6])
#Print condition sample counts
print('\n' + printFormat.bold + 'Display trial counts for each condition' + printFormat.end)
print(printFormat.bold +'Win: ' + printFormat.end + str(np.sum(syntheticEEG[:,0]==0)))
print(printFormat.bold +'Lose: ' + printFormat.end + str(np.sum(syntheticEEG[:,0]==1)))
Display first few rows/columns of data ['Condition', 'Time1', 'Time2', 'Time3', 'Time4', 'Time5'] [[0. 0.4462961 0.43793494 0.41811001 0.40980363 0.43475217] [1. 0.42722881 0.44610953 0.4300316 0.45458001 0.45692587] [0. 0.38341969 0.35869789 0.37482953 0.40516382 0.4458825 ]] Display trial counts for each condition Win: 5000 Lose: 5000
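Note that the synthetic samples are on a normalized scale (values of roughly 0 to 1) rather than in microvolts like the empirical data, which is one reason both datasets are z-scored (with scikit-learn's scale function) before they are plotted together or used for classification below.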
Step 3.2. View Data¶
Step 3.2.1. View Trial-Level Data¶
We will view five trial-level samples from both the empirical and synthetic data.
#Determine 5 random trials to plot
empiricalIndex = rnd.sample(range(0, empiricalEEG.shape[0]), 5)
syntheticIndex = rnd.sample(range(0, syntheticEEG.shape[0]), 5)
#Plot trial data
f, ax = plt.subplots(5, 2, figsize=(12, 4))
for c in range(5):
    #Plot a random empirical trial
    ax[c,0].plot(time, empiricalEEG[empiricalIndex[c],3:])
    ax[c,0].set_yticks([])

    #Plot a random synthetic trial
    ax[c,1].plot(time, syntheticEEG[syntheticIndex[c],1:])
    ax[c,1].spines[['left', 'right', 'top']].set_visible(False)
    ax[c,1].set_yticks([])

    #Label only the first row of panels
    if c == 0:
        ax[c,0].set_title('Empirical', loc='left')
        ax[c,1].set_title('Synthetic', loc='left')
    else:
        ax[c,0].set_title(' ')
        ax[c,1].set_title(' ')

    #Only the bottom row keeps its x axis
    if c != 4:
        ax[c,0].spines[['bottom', 'left', 'right', 'top']].set_visible(False)
        ax[c,1].spines[['bottom', 'left', 'right', 'top']].set_visible(False)
        ax[c,0].set_xticks([])
        ax[c,1].set_xticks([])
    else:
        ax[c,0].spines[['left', 'right', 'top']].set_visible(False)
        ax[c,1].spines[['left', 'right', 'top']].set_visible(False)
        ax[c,0].set_xlabel('Time (ms)')
        ax[c,1].set_xlabel('Time (ms)')
Step 3.2.2. View ERP Data¶
We will now show the empirical and synthetic ERPs side-by-side for comparison.
#Grand average the synthetic waveforms for each condition
synLossWaveform = np.mean(syntheticEEG[np.r_[syntheticEEG[:,0]==1],1:],axis=0)
synWinWaveform = np.mean(syntheticEEG[np.r_[syntheticEEG[:,0]==0],1:],axis=0)
#Set up figure
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
#Plot the win waveforms (note, we z-score both simply for visualization)
ax1.plot(time, scale(winWaveform), label = 'Empirical')
ax1.plot(time, scale(synWinWaveform), label = 'Synthetic')
#Format plot
ax1.set_ylabel('Voltage ($\mu$V)')
ax1.set_xlabel('Time (ms)')
ax1.set_title('Win', loc='left')
ax1.spines[['right', 'top']].set_visible(False)
ax1.tick_params(left = False, labelleft = False)
ax1.legend(frameon=False)
#Plot the lose waveforms
ax2.plot(time, scale(lossWaveform), label = 'Empirical')
ax2.plot(time, scale(synLossWaveform), label = 'Synthetic')
#Format plot
ax2.set_ylabel('Voltage ($\mu$V)')
ax2.set_xlabel('Time (ms)')
ax2.set_title('Lose', loc='left')
ax2.spines[['right', 'top']].set_visible(False)
ax2.tick_params(left = False, labelleft = False)
ax2.legend(frameon=False)
[Figure: empirical versus synthetic grand-averaged ERPs for the Win (left) and Lose (right) conditions]
Step 4. Classification Setup¶
Step 4.1. Preparing Validation Data¶
We also provide a validation dataset with samples not contained in the empirical dataset. Here, we prepare them for classification.
#Set seed for a bit of reproducibility
rnd.seed(1618)
#This function averages trial-level empirical data for each participant and condition
def averageEEG(EEG):
    participants = np.unique(EEG[:,0])
    averagedEEG = []
    for participant in participants:
        for condition in range(2):
            averagedEEG.append(np.mean(EEG[(EEG[:,0]==participant)&(EEG[:,1]==condition),:], axis=0))
    return np.array(averagedEEG)
#Load test data to predict (data that neither the GAN nor the classifier will ever see in training)
EEGDataTest = np.genfromtxt('data/gansEEGValidationData.csv', delimiter=',', skip_header=1)
EEGDataTest = averageEEG(EEGDataTest)[:,1:]
#Extract test outcome and predictor data
y_test = EEGDataTest[:,0]
x_test = EEGDataTest[:,2:]
x_test = scale(x_test,axis = 1)
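Note that scale(x_test, axis=1) z-scores each averaged waveform across its 100 time points, so the classifiers operate on the shape of each waveform rather than its absolute voltage; we apply the same scaling to the empirical, synthetic, and augmented training sets below.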
Step 4.2. Preparing Empirical Data¶
We now prepare the empirical training set. Our predictors will be the entire time series of 100 datapoints; however, in the manuscript we also ran parallel classifications using three extracted EEG features.
#Create participant by condition averages
Emp_train = averageEEG(empiricalEEG)[:,1:]
#Extract the outcomes
Emp_Y_train = Emp_train[:,0]
#Scale the predictors
Emp_X_train = scale(Emp_train[:,2:], axis=1)
#Shuffle the order of samples
trainShuffle = rnd.sample(range(len(Emp_X_train)),len(Emp_X_train))
Emp_Y_train = Emp_Y_train[trainShuffle]
Emp_X_train = Emp_X_train[trainShuffle,:]
Step 4.3. Preparing Augmented Data¶
We will prepare the augmented dataset by first processing the synthetic data as we did with the empirical data, averaging the 10,000 synthetic trials in bundles of 50 within each condition (yielding 100 averaged samples per condition), and then appending these synthetic samples to the empirical training set to create an augmented dataset.
#This function averages trial-level synthetic data in bundles of 50 trials, constrained to each condition
def averageSynthetic(synData):
    samplesToAverage = 50

    #Split the synthetic data by condition (WIN = 0, LOSE = 1)
    winSynData = synData[synData[:,0]==0,:]
    lossSynData = synData[synData[:,0]==1,:]

    #Determine the starting index of each bundle of 50 trials
    winTimeIndices = np.arange(0,winSynData.shape[0],samplesToAverage)
    lossTimeIndices = np.arange(0,lossSynData.shape[0],samplesToAverage)

    #Average each bundle and re-attach its condition label
    newWinSynData = [np.insert(np.mean(winSynData[int(trialIndex):int(trialIndex)+samplesToAverage,1:],axis=0),0,0) for trialIndex in winTimeIndices]
    newLossSynData = [np.insert(np.mean(lossSynData[int(trialIndex):int(trialIndex)+samplesToAverage,1:],axis=0),0,1) for trialIndex in lossTimeIndices]

    avgSynData = np.vstack((np.asarray(newWinSynData),np.asarray(newLossSynData)))
    return avgSynData
#Create 'participant' by condition averages
Syn_train = averageSynthetic(syntheticEEG)
#Extract the outcomes
Syn_Y_train = Syn_train[:,0]
#Scale the predictors
Syn_X_train = scale(Syn_train[:,1:], axis=1)
#Combine empirical and synthetic datasets to create an augmented dataset
Aug_Y_train = np.concatenate((Emp_Y_train,Syn_Y_train))
Aug_X_train = np.concatenate((Emp_X_train,Syn_X_train))
#Shuffle the order of samples
trainShuffle = rnd.sample(range(len(Aug_X_train)),len(Aug_X_train))
Aug_Y_train = Aug_Y_train[trainShuffle]
Aug_X_train = Aug_X_train[trainShuffle,:]
Step 5. Support Vector Machine¶
Step 5.1. Define Search Space¶
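We define a hyperparameter search space for the SVM spanning the regularization strength (C), the kernel coefficient (gamma), and the kernel type, for a total of 4 × 4 × 3 = 48 combinations. GridSearchCV (used in the next steps) evaluates each combination with cross-validation and, because refit = True, refits the best model on the full training set before predicting the validation data.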
#Determine SVM search space
param_grid_SVM = [
{'C': [0.1, 1, 10, 100],
'gamma': [1, 0.1, 0.01, 0.001],
'kernel': ['rbf', 'poly', 'sigmoid']}]
Step 5.2. Classify Empirical Data¶
#Setup tracking variable
predictionScores_SVM = []
#Setup SVM grid search
optimal_params = GridSearchCV(
SVC(),
param_grid_SVM,
refit = True,
verbose = False)
#Conduct classification
optimal_params.fit(Emp_X_train, Emp_Y_train)
SVMOutput = optimal_params.predict(x_test)
#Determine performance
predictResults = classification_report(y_test, SVMOutput, output_dict=True)
predictionScores_SVM.append(round(predictResults['accuracy']*100))
Step 5.3. Classify Augmented Data¶
#Setup SVM grid search
optimal_params = GridSearchCV(
SVC(),
param_grid_SVM,
refit = True,
verbose = False)
#Conduct classification
optimal_params.fit(Aug_X_train, Aug_Y_train)
SVMOutput = optimal_params.predict(x_test)
#Determine performance
predictResults = classification_report(y_test, SVMOutput, output_dict=True)
predictionScores_SVM.append(round(predictResults['accuracy']*100))
#Report results
print('Empirical Classification Accuracy: ' + str(predictionScores_SVM[0]) + '%')
print('Augmented Classification Accuracy: ' + str(predictionScores_SVM[1]) + '%')
Empirical Classification Accuracy: 52% Augmented Classification Accuracy: 66%
Step 6. Neural Network¶
Step 6.1. Define Search Space¶
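We define an analogous search space for the neural network (scikit-learn's MLPClassifier), spanning the hidden-layer architecture, activation function, solver, L2 regularization strength (alpha), and learning-rate schedule, for a total of 5 × 3 × 2 × 2 × 3 = 180 combinations (matching the "180 candidates" reported by GridSearchCV below); max_iter is fixed at 10000 so each candidate has enough iterations to converge.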
#Determine neural network search space
param_grid_NN = [
{'hidden_layer_sizes': [(25,), (50,), (25, 25), (50,50), (50,25,50)],
'activation': ['logistic', 'tanh', 'relu'],
'solver': ['sgd', 'adam'],
'alpha': [0.0001, 0.05],
'learning_rate': ['constant', 'invscaling', 'adaptive'],
'max_iter' : [10000]}]
Step 6.2. Classify Empirical Data¶
#Signify computational time
print('This may take a few minutes...')
#Setup tracking variable
predictionScores_NN = []
#Setup neural network grid search
optimal_params = GridSearchCV(
MLPClassifier(),
param_grid_NN,
verbose = True,
n_jobs = -1)
#Conduct classification
optimal_params.fit(Emp_X_train, Emp_Y_train);
neuralNetOutput = MLPClassifier(hidden_layer_sizes=optimal_params.best_params_['hidden_layer_sizes'],
activation=optimal_params.best_params_['activation'],
solver = optimal_params.best_params_['solver'],
alpha = optimal_params.best_params_['alpha'],
learning_rate = optimal_params.best_params_['learning_rate'],
max_iter = optimal_params.best_params_['max_iter'])
neuralNetOutput.fit(Emp_X_train, Emp_Y_train)
y_true, y_pred = y_test , neuralNetOutput.predict(x_test)
#Determine performance
predictResults = classification_report(y_true, y_pred, output_dict=True)
predictScore = round(predictResults['accuracy']*100)
predictionScores_NN.append(predictScore)
This may take a few minutes... Fitting 5 folds for each of 180 candidates, totalling 900 fits
Step 6.3. Classify Augmented Data¶
#Signify computational time
print('This may take twice as long as the empirical neural network classification...')
#Setup neural network grid search
optimal_params = GridSearchCV(
MLPClassifier(),
param_grid_NN,
verbose = True,
n_jobs = -1)
#Conduct classification
optimal_params.fit(Aug_X_train, Aug_Y_train);
neuralNetOutput = MLPClassifier(hidden_layer_sizes=optimal_params.best_params_['hidden_layer_sizes'],
activation=optimal_params.best_params_['activation'],
solver = optimal_params.best_params_['solver'],
alpha = optimal_params.best_params_['alpha'],
learning_rate = optimal_params.best_params_['learning_rate'],
max_iter = optimal_params.best_params_['max_iter'])
neuralNetOutput.fit(Aug_X_train, Aug_Y_train)
y_true, y_pred = y_test , neuralNetOutput.predict(x_test)
#Determine performance
predictResults = classification_report(y_true, y_pred, output_dict=True)
predictScore = round(predictResults['accuracy']*100)
predictionScores_NN.append(predictScore)
#Report results
print('Empirical Classification Accuracy: ' + str(predictionScores_NN[0]) + '%')
print('Augmented Classification Accuracy: ' + str(predictionScores_NN[1]) + '%')
This may take twice as long as the empirical neural network classification... Fitting 5 folds for each of 180 candidates, totalling 900 fits Empirical Classification Accuracy: 70% Augmented Classification Accuracy: 69%
Step 7. Final Report¶
Step 7.1. Present Classification Performance¶
We present the performance accuracies in text.
#Report results
print(printFormat.bold + 'SVM Classification Results:' + printFormat.end)
print('Empirical Classification Accuracy: ' + str(predictionScores_SVM[0]) + '%')
print('Augmented Classification Accuracy: ' + str(predictionScores_SVM[1]) + '%')
#Report results
print('\n' + printFormat.bold + 'Neural Network Classification Results:' + printFormat.end)
print('Empirical Classification Accuracy: ' + str(predictionScores_NN[0]) + '%')
print('Augmented Classification Accuracy: ' + str(predictionScores_NN[1]) + '%')
print(