Final Project: Mouse Microbiomes

AC209b | Spring 2021

Group Members: Victor Avram, Blake Bullwinkel, Teresa Datta, and Kristen Grabarz

Video Link: https://harvard.zoom.us/rec/share/fcqGa-4MSyx5K5Q_A_0jOQSy8LKkaBbVE02nMtuX9dL6Or3CwzO8cT2MR8zQ164R.IXfwh9d0K9pP9Y7-?startTime=1620673889000


In [ ]:
# Import relevant packages
import pandas as pd
import numpy as np 
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics import f1_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import MinMaxScaler
import tensorflow as tf
from tensorflow.keras import backend
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Input, SimpleRNN, Embedding, Dense, \
                            TimeDistributed, GRU, Dropout, Bidirectional, \
                            Conv1D, BatchNormalization, LSTM
from tensorflow.keras.models import model_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import mean_squared_error, r2_score
import pandasql as ps
import random
import editdistance
import seaborn as sns

1. Background, Context & Motivation

1.1 Background & Motivation

The microbiome, composed of trillions of cells, is an important factor for human health. Microbial communities are critical in influencing function of the digestive tract, brain, skin, immunity, and reproductive systems, and disruptions to the balance of one’s microbiome can give rise to negative health effects such as infections or neurological diseases. In light of this, and given the growing interest in microbiome-based therapies such as bacteriotherapies, a wide body of previous research has illuminated important insights. In recent decades, the human microbiome has garnered broad attention from researchers — for example, the Human Microbiome Project) represents an initiative to understand how microbial components of human genetics and metabolic systems contribute to physiology and disease.

1.2 Previous Literature

Previous research has illuminated a great deal about the benefits and dynamics of the microbiome. Importantly, microbiomes are also inherently dynamic, and can change over time due to such factors as maturation, diet changes, illness or medical interventions. A body of research, leveraging technology and computational methods, has focused on elucidating the dynamics of microbiota and exploring methods of modeling it mathematically in response to natural temporal variability and perturbations like antibiotics or dietary changes. Novel methods of inferring dynamical systems models from microbiome time series data have emerged, such as Microbial Dynamical Systems Inference Engine (MDSINE), which has been shown to perform accurate forecasting of microbial dynamics, prediction of stable sub-communities inhibiting pathogen growth, and identification of bacteria that is vital to community integrity in response to perturbations. Further, challenges of microbiome data include high dimensionality, temporal sparsity, and non-uniformly sampled data, and methods such as a Bayesian nonparametric model based on interaction modules (ICML) have been shown to be effective in gaining biological insights into microbial therapy design.

Because of the link between microbial communities and health, an area of work has evolved to develop microbiome-based therapies such as faecal microbiota transplantation, in which healthy bacteria is transferred through stool from a donor into the intestine of a patient to restore the microbial balance and help fight infection. Research has indicated that despite nuances in microbial communities from one subject to another, gut and mouth microbiomes display universal dynamics, which can be useful in designing generalized microbiome therapy methods. In this way, mice have been used to address how bacteriophages impact bacterial communities in the gut. Hsu et al. utilize this approach to specifically show how phages not only can coexist overtime with specific bacteria, but can also specifically induce cascading effects on not directly targeted microorganisms and impact the gut metabolome.

1.3 Research Questions

Motivated by this previous work, along with deep learning methods explored throughout the class, our project aims to leverage deep learning methods to perform quantitiative regression on time series qPCR measurement data, as well as classification of donor health status based on microbiome profile.

2. Description of Data

For this project, we work with data from a mouse study to better understand the workings of the microbiome. Over the span of 70 days, various perturbations were performed on mice with microbiomes received from either healthy or infected sample donors. Perturbations included a high fat diet, Vancomycin, and Gentamicin, and were applied to mice at the same point. in time. An overview of the study design is shown in the following figure

We utilized five TSV files that were provided by the teaching team:

  • asv_and_taxonomy.tsv : information on taxonomy of each ASV bacteria
  • counts.tsv: Relative abundance of different taxa of bacteria in each sample
  • metadata.tsv: Connects each sampleID to mouse identifier and date of sample
  • perturbations.tsv: When each mouse subject underwent each perturbation
  • qpcr.tsv: Total bacterial concentrations for each sample
In [ ]:
# Read in relevant datasets
asv_and_taxonomy = pd.read_csv('data/asv_and_taxonomy.tsv', sep='\t')
metadata = pd.read_csv('data/metadata.tsv', sep='\t')
perturbations = pd.read_csv('data/perturbations.tsv', sep='\t')
qpcr = pd.read_csv('data/qpcr.tsv', sep='\t')
counts = pd.read_csv('data/counts.tsv', sep='\t')
metadata = pd.read_csv('data/metadata.tsv', sep='\t')

2.1 Data Cleaning & Preparation

Along with the information provided by the guest lecturers, we performed initial explorations of the data through visualizations which we will discuss in the next section. Missing data was not found to be a concern. The first step in our data cleaning process was to transform the ASV counts to relative abundances for interpretation and use in our future models. Recall that counts are not absolute, but are rather relative to the read depth. The total read depth for each sample can be computed by summing the counts, and then relative abundance can be derived by dividing counts by the read depth.

In [ ]:
# Convert the first column specifying the species to the index
counts = counts.set_index('Unnamed: 0', drop=True)

# Perform sequencing depth normalization
read_depth=counts.sum()
counts = counts.div(read_depth, axis=1)

Next, we performed additional cleaning of the ASV count data by limiting it to only those species that appear at non-negligible frequencies. To accomplish this, we filtered out any species in the ASV counts dataset having a maximum relative abundance of less than 0.00005. In doing so, we found that nearly 40% of bacteria taxa (587 out of 1473 rows) had max relative abundance of less than 0.00005 and were removed.

In [ ]:
def remove_nonspecies(c): 
    # removes all bacteria taxa species (rows) in counts df which have a max relative abundance of less than 0.00005
    c['max_count'] = c.max(axis = 1)
    c = c[c['max_count']> 0.00005]
    c.drop(columns = ['max_count'])
    return c

counts_cleaned = remove_nonspecies(counts)

num_removed = counts.shape[0]- counts_cleaned.shape[0]
print(num_removed,"out of", counts.shape[0], "rows (bacteria taxa) had max relative abundance of less than 0.00005 and were removed")
587 out of 1473 rows (bacteria taxa) had max relative abundance of less than 0.00005 and were removed

3. Select EDA

Before building models, we performed exploratory data analysis to build contextual knowledge of the data's structure and patterns, as well as provide footing for our subsequnt modeling objectives.

In [ ]:
# Determine the number of bateria species and the number of samples
num_species,num_samples = counts.shape
print(f"There are {num_species} unique species.")
print(f"There are {num_samples} samples.")

# Determine the number of unique classes and the number of unique orders
num_classes,num_orders = asv_and_taxonomy['Class'].nunique(),asv_and_taxonomy['Order'].nunique()
print(f"There are {num_classes} unique classes.")
print(f"There are {num_orders} unique orders.")
There are 1473 unique species.
There are 687 samples.
There are 24 unique classes.
There are 43 unique orders.

3.1 qPCR Over Time for Healthy & UC Donor Mice

First, we explore the qPCR measurements for a healthy and UC donor mouse to assess the trends over time and visualize responses to perturbations.

In the plots below, we show the qPCR time-series for Mice 2 and 6, each of which has three measurements indicating the total bacterial concentrations in the mice over the course of the study. While there are some similarities between the plots, it is clear that the mice respond differently to the same perturbations. In particular, Mouse 2 appears to have larger perturbations during the Gentamicin period, while Mouse 6 had very large perturbations during the Vancomycin period.

In [ ]:
# Plot qPCR for Healthy v. UC Donor Mouse

# Make a copy of qpcr for plotting
qpcr_copy = qpcr.copy()

# Add the 'subject' and 'time' in metadata to qpcr
qpcr_copy = qpcr_copy.merge(metadata, left_on='sampleID', right_on='sampleID')

# Filter the data for mouse 2 and mouse 6
mouse2_data = qpcr_copy[qpcr_copy['subject'] == 2]
mouse6_data = qpcr_copy[qpcr_copy['subject'] == 6]

# Define variables for plotting
measurements = ['measurement1', 'measurement2', 'measurement3']
mice_nums = [2, 6]
dfs = [mouse2_data, mouse6_data]

# Plot the time series of measurement1, measurement2, measurement3 for mouse 2 and mouse 6
fig, axs = plt.subplots(1, 2, figsize=(21, 6))
for i in range(2):
    for m in measurements:
        sorted_idx = np.argsort(dfs[i]['time'].to_numpy())
        sorted_measurements = dfs[i][m].to_numpy()[sorted_idx]
        axs[i].plot(sorted(dfs[i]['time']), sorted_measurements, label=m)
    axs[i].set_xlabel('Time')
    axs[i].set_ylabel('qPCR (CFU/g)')
    axs[i].set_title('qPCR Time-Series for Mouse Subject '+str(mice_nums[i]))
    axs[i].legend()
plt.show()

3.2 Relative Abundance of Taxa over Time by Healthy vs. UC Donor Mice

As part of our EDA, we explored how the relative abundance of bacterial species changes over time among the mice. Specifically, we are interested in what happens during and after a given perturbation (i.e. introduction of a broad-spectrum antibiotic). The figures below show the changes over time in the relative abundance of different phyla for a mouse with a healthy donor and one with a UC donor. We also explored breakdowns by other taxonomic groupings, such as class or family, but excluded these from our final code report for the sake of brevity. A breakdown by kingdom indicated thta dominant kingdom present in both mice is Bacteria. There are no obvious changes in the relative abundance of either kingdom (there are effectively no Archaea relative to Bacteria).

Vancomycin was introduced at 35.5 days and stopped at 42.5 days. Gentamycin was introduced at 50.5 days and stopped at 57.5 days. Interestingly, we see a similar relative abundance pattern between the two mice. However, exposure to Vancomycin produces a noticeable increase in the abundance of proteobacteria along with a reduction in the abundance of Actinobacteria in mouse 2 (healthy donor). This signature is not found in mouse 6 (unhealthy donor).

Additionally, these results indicate that at a given point in time, the majority of each mouse's microbiome can be broken down into four to five groups. Also, the breakdown of relative abundances does not change materially with more granular taxonomic breakdowns (e.g. a similar number of groups make up most of the relative abundance at the Phylum level as at the Class or Order level). This suggests that grouping the taxa at the Phylum or Class level should be sufficient to understand broad changes in the microbiomes over time.

In [ ]:
# Group the counts dataframe by phylum and compute the relative abundance per phylum
phylum_counts = counts.assign(Class = asv_and_taxonomy['Phylum'].to_numpy()).groupby('Class').sum()

# Pull the samples associated with mouse 2 (inoculated with a healthy microbiome) and mouse 6 (inoculated with an unhealthy microbiome)
mouse2_counts = phylum_counts.loc[:,phylum_counts.columns.str.startswith('2')]
mouse6_counts = phylum_counts.loc[:,phylum_counts.columns.str.startswith('6')]
mouse2_counts = mouse2_counts.transpose()
mouse6_counts = mouse6_counts.transpose()

# Pull the metadata associated with mouse 2 and the metadata associated with mouse 6
mouse2_metadata = metadata.loc[metadata['sampleID'].str.startswith('2'),:]
mouse6_metadata = metadata.loc[metadata['sampleID'].str.startswith('6'),:]

# Sort the metadata and the counts data based on the time
mouse2_ordered_indices = np.argsort(mouse2_metadata['time'])
mouse2_metadata = mouse2_metadata.iloc[mouse2_ordered_indices,:]
mouse2_counts = mouse2_counts.iloc[mouse2_ordered_indices,:]
mouse6_ordered_indices = np.argsort(mouse6_metadata['time'])
mouse6_metadata = mouse6_metadata.iloc[mouse6_ordered_indices,:]
mouse6_counts = mouse6_counts.iloc[mouse6_ordered_indices,:]
In [ ]:
# Create plots of the relative abundance of each phylum
fig, axes = plt.subplots(1, 2, figsize = (14, 6))
ax = axes.flatten()
# Iterate over the classes for mouse 2
for col in mouse2_counts:
    ax[0].plot(mouse2_metadata['time'], mouse2_counts[col], label = col)
ax[0].set_xlabel("Time (days)", fontsize = 14)
ax[0].set_ylabel("Relative Abundance", fontsize = 14)
ax[0].set_title("Mouse 2 - Relative Abundance vs Time", fontsize = 14)
# Iterate over the classes for mouse 2
for col in mouse6_counts:
    ax[1].plot(mouse6_metadata['time'], mouse6_counts[col], label = col)
ax[1].set_xlabel("Time (days)", fontsize = 14)
ax[1].set_ylabel("Relative Abundance", fontsize = 14)
ax[1].set_title("Mouse 6 - Relative Abundance vs Time", fontsize = 14)
ax[1].legend(loc='center left', bbox_to_anchor=(1.05, 0.5), fontsize=12)

fig.suptitle('Phylum Relative Abundance', fontsize=16)
plt.tight_layout();

plt.savefig('plots/phylum_abundance_over_time.png')

3.3 ASV Gene Groupings

We also sought to understand how we might identify biological groups based on the ASV gene sequences themselves. To explore this, we calculated the edit distance (the minimum number of operations required to transform one string into the other) between each unique amplicon sequence and all of the other sequences, again excluding the taxa identified as having extremely low relative abundances as noted above.

As can be seen in the plots below, which includes both a small sample close-up view and a version with all 886 non-negligible species sequences, many of the taxa have similar genetic profiles, reinforcing the hypothesis that clustering the ASVs taxonomically might help provide an aggregated view of the mice microbiomes.

In [ ]:
# Create heat map of genetic sequence edit distances

# Add name as col for counts clean
counts_clean_w_names = counts_cleaned.copy()
counts_clean_w_names.index.name = 'name'
counts_clean_w_names.reset_index(inplace=True)

# Join counts_cleaned to asv_and_taxonomy to get grouping options
counts_and_tax = counts_clean_w_names.merge(asv_and_taxonomy, on = 'name')
sequences = counts_and_tax['sequence'].unique()
sequences_subset = sequences[:20]

def create_heatmap(sequences):
    n = len(sequences)
    distance_matrix = np.zeros((n,n))#[[0]*len(sequences)]*len(sequences)

    # N x N Heat Map
    for i in range(len(sequences)):
        for j in range(len(sequences)):
            distance = editdistance.eval(sequences[i], sequences[j])
            distance_matrix[i][j]=distance
    
    return distance_matrix

distance_mat = create_heatmap(sequences_subset)
distance_mat_full = create_heatmap(sequences)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
fig.suptitle('Heatmap of Edit Distances Between ASV Sequences')
sns.heatmap(distance_mat, ax=axes[0])
axes[0].set_title('Sample of Sequences')
sns.heatmap(distance_mat_full, ax=axes[1])
axes[1].set_title('All Non-Negligible ASV Sequences')
Out[ ]:
Text(0.5, 1.0, 'All Non-Negligible ASV Sequences')

4. Modeling

4.1 Predicting qPCR Measurements

Given the time-series format of the qPCR measurements, we were interested in using autoregressive models to predict future readings based on lagged qPCR measurements. Specifically, we split the data into measurements that came from mice that received transplants from healthy human donors versus human donors infected with ulcerative colitis (UC) and trained our models on three lagged measurements. We decided to split the data after noticing that models trained on the full data set performed quite poorly, most likely because the nature of qPCR time-series for mice in the two groups are very different.

Before modeling, we used the functions below to reformat the data into a form that is easier to work with and transform the qPCR measurements onto the log10 scale to avoid exploding gradients when training neural network models.

In [ ]:
# Define a function to get the nth previous measurement for a particular qPCR measurement m
def get_prev_measurement(x, n, m): 
    subj = int(metadata[metadata['sampleID']==x]['subject'])
    date = int(metadata[metadata['sampleID']==x]['time'])
    samples = metadata[metadata['subject']==subj]
    samples = samples[samples['time']<date].sort_values(by = ['time'], ascending = False)
    n = n-1
    if samples.shape[0]>n:
        s = str(samples.iloc[n]['sampleID'])
    else: 
        return "nan"
    if m == 1:
        return float(qpcr[qpcr['sampleID'] == s]['measurement1'])
    elif m == 2:
        return float(qpcr[qpcr['sampleID'] == s]['measurement2'])
    elif m == 3:
        return float(qpcr[qpcr['sampleID'] == s]['measurement3'])
    else:
        raise Exception('measurement m must be in [1,2,3]')

# Define a function to get the full date frame of lagged qPCR measurements
def get_qpcr_full(df=qpcr, timesteps=3):

    # Get the previous three measurements for each measurement of each row in qpcr
    qpcr_bm = df.merge(metadata, left_on='sampleID', right_on='sampleID')
    for m in range(1,4): # Loop through measurements m
        for t in range(1,timesteps+1): # Loop through previous timesteps
            col_name = 'm'+str(m)+'_prev_'+str(t)
            qpcr_bm[col_name] = qpcr_bm.sampleID.apply(lambda x: get_prev_measurement(x,t,m))
        qpcr_bm = qpcr_bm[qpcr_bm[col_name]!= 'nan'] # At this point, col_name will be last timestep

    # Add the donor indicator variable
    qpcr_bm['donor'] = np.where(qpcr_bm.subject.isin([6,7,8,9,10]), 1, 0)

    # Add the perturbation binary variables (note start and end time are the same for all mice)
    qpcr_bm['diet'] = qpcr_bm.time.between(21.5, 28.5).astype(int)
    qpcr_bm['vancomycin'] = qpcr_bm.time.between(35.5, 42.5).astype(int)
    qpcr_bm['gentamicin'] = qpcr_bm.time.between(50.5, 57.5).astype(int)

    # Initialize and fill a new data frame with one measurement on each row
    col_names = ['measurement', 'donor', 'diet', 'vancomycin', 'gentamicin']
    for t in range(1,timesteps+1):
        col_names.append('prev_'+str(t))
    qpcr_full = pd.DataFrame(columns=col_names)
    for i, row in qpcr_bm.iterrows():
        for m in range(1,4):
            measurement_col = 'measurement'+str(m)
            new_row = [np.log(row[measurement_col]), row['donor'], row['diet'], row['vancomycin'], row['gentamicin']]
            for t in range(1,timesteps+1):
                new_row.append(np.log(row['m'+str(m)+'_prev_'+str(t)])) # transform qPCR to log scale
            qpcr_full.loc[3*i+(m-1)] = new_row
    return qpcr_full
In [ ]:
# Call the function and print out the first few lines
qpcr_full = get_qpcr_full(qpcr, timesteps=3)
qpcr_full.head(10)
Out[ ]:
measurement donor diet vancomycin gentamicin prev_1 prev_2 prev_3
3 25.551351 1.0 0.0 0.0 0.0 25.883120 25.945696 26.431573
4 25.718158 1.0 0.0 0.0 0.0 25.401576 26.332394 25.465093
5 24.795046 1.0 0.0 0.0 0.0 25.585563 26.149845 26.567201
6 25.686418 1.0 0.0 0.0 0.0 25.551351 25.883120 25.945696
7 26.096020 1.0 0.0 0.0 0.0 25.718158 25.401576 26.332394
8 26.404120 1.0 0.0 0.0 0.0 24.795046 25.585563 26.149845
9 25.010251 1.0 0.0 0.0 0.0 25.686418 25.551351 25.883120
10 25.996749 1.0 0.0 0.0 0.0 26.096020 25.718158 25.401576
11 24.964518 1.0 0.0 0.0 0.0 26.404120 24.795046 25.585563
12 25.262775 1.0 0.0 0.0 0.0 25.010251 25.686418 25.551351
In [ ]:
# Split the full data into train and test
X = qpcr_full[['prev_1', 'prev_2', 'prev_3']]
y = qpcr_full['measurement']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Split the data into healthy and UC donor status
qpcr_healthy = qpcr_full[qpcr_full['donor']==0]
qpcr_uc = qpcr_full[qpcr_full['donor']==1]

# Split each of these into train and test
X_healthy = qpcr_healthy[['prev_1', 'prev_2', 'prev_3']]
y_healthy = qpcr_healthy['measurement']
X_healthy_train, X_healthy_test, y_healthy_train, y_healthy_test = train_test_split(X_healthy, y_healthy, test_size=0.25, random_state=42)
X_uc = qpcr_uc[['prev_1', 'prev_2', 'prev_3']]
y_uc = qpcr_uc['measurement']
X_uc_train, X_uc_test, y_uc_train, y_uc_test = train_test_split(X_uc, y_uc, test_size=0.25, random_state=42)

4.1.1 Autoregressive Model

First, we train an autoregressive model on the separate healthy and UC donor measurements in order to see whether we can predict future qPCR measurements based on linear combinations of the three previous measurements.

In [ ]:
# Train and evaluate an autoregressive model on the healthy donor measurements
healthy_autoreg = LinearRegression()
healthy_autoreg.fit(X_healthy_train, y_healthy_train)
healthy_autoreg_train_pred = healthy_autoreg.predict(X_healthy_train)
healthy_autoreg_test_pred = healthy_autoreg.predict(X_healthy_test)
print('Autoregressive Model Results - Healthy Donor')
print('Train MSE: %.4f'
      % mean_squared_error(y_healthy_train, healthy_autoreg_train_pred))
print('Train R2: %.4f'
      % r2_score(y_healthy_train, healthy_autoreg_train_pred))
print('Test MSE: %.4f'
      % mean_squared_error(y_healthy_test, healthy_autoreg_test_pred))
print('Test R2: %.4f'
      % r2_score(y_healthy_test, healthy_autoreg_test_pred))

# Train and evaluate an autoregressive model on the UC donor measurements
uc_autoreg = LinearRegression()
uc_autoreg.fit(X_uc_train, y_uc_train)
uc_autoreg_train_pred = uc_autoreg.predict(X_uc_train)
uc_autoreg_test_pred = uc_autoreg.predict(X_uc_test)
print('Autoregressive Model Results - UC Donor')
print('Train MSE: %.4f'
      % mean_squared_error(y_uc_train, uc_autoreg_train_pred))
print('Train R2: %.4f'
      % r2_score(y_uc_train, uc_autoreg_train_pred))
print('Test MSE: %.4f'
      % mean_squared_error(y_uc_test, uc_autoreg_test_pred))
print('Test R2: %.4f'
      % r2_score(y_uc_test, uc_autoreg_test_pred))
Autoregressive Model Results - Healthy Donor
Train MSE: 0.6188
Train R2: 0.3419
Test MSE: 0.5134
Test R2: 0.4228
Autoregressive Model Results - UC Donor
Train MSE: 0.7372
Train R2: 0.2253
Test MSE: 1.2492
Test R2: 0.1479

4.1.2 Long-Short Term Memory (LSTM)

Next, we use the same data to train a Long Short-Term Memory (LSTM) network, which is a type of Recurrent Neural Network (RNN) architecture. By introducing concepts of memory and context history, they are particularly well suited for time series data. Before modeling, we first reshape the data into the format required by Keras.

In [ ]:
# Reshape original data
X_train = np.array(X_train)
X_train = X_train.reshape((X_train.shape[0], X_train.shape[1], 1))
y_train = np.array(y_train)
X_test = np.array(X_test)
X_test = X_test.reshape((X_test.shape[0], X_test.shape[1], 1))
y_test = np.array(y_test)

# Reshape healthy donor arrays
X_healthy_train = np.array(X_healthy_train)
X_healthy_train = X_healthy_train.reshape((X_healthy_train.shape[0], X_healthy_train.shape[1], 1))
y_healthy_train = np.array(y_healthy_train)
X_healthy_test = np.array(X_healthy_test)
X_healthy_test = X_healthy_test.reshape((X_healthy_test.shape[0], X_healthy_test.shape[1], 1))
y_healthy_test = np.array(y_healthy_test)

# Reshape UC donor arrays
X_uc_train = np.array(X_uc_train)
X_uc_train = X_uc_train.reshape((X_uc_train.shape[0], X_uc_train.shape[1], 1))
y_uc_train = np.array(y_uc_train)
X_uc_test = np.array(X_uc_test)
X_uc_test = X_uc_test.reshape((X_uc_test.shape[0], X_uc_test.shape[1], 1))
y_uc_test = np.array(y_uc_test)
In [ ]:
# Define a function that creates and compiles a LSTM
def get_lstm(model_name):
    lstm = tf.keras.Sequential(name=model_name)
    lstm.add(Bidirectional(LSTM(64), input_shape=(3,1)))
    lstm.add(Dense(128, activation = "relu"))
    lstm.add(Dense(64, activation = "relu"))
    lstm.add(Dense(1))
    lstm.compile(optimizer=tf.optimizers.Adam(1e-4), loss='mean_squared_error')
    return lstm

# Instantiate an LSTM model for the healthy donor and print summary
lstm_healthy = get_lstm('Healthy')
print(lstm_healthy.summary())
Model: "Healthy"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
bidirectional (Bidirectional (None, 128)               33792     
_________________________________________________________________
dense (Dense)                (None, 128)               16512     
_________________________________________________________________
dense_1 (Dense)              (None, 64)                8256      
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 65        
=================================================================
Total params: 58,625
Trainable params: 58,625
Non-trainable params: 0
_________________________________________________________________
None
In [ ]:
# Train the model on the healthy donor data
healthy_history = lstm_healthy.fit(X_healthy_train, y_healthy_train, epochs=200, verbose=0)
In [ ]:
# Predict on train and test
y_healthy_train_pred = lstm_healthy.predict(X_healthy_train)
y_healthy_test_pred = lstm_healthy.predict(X_healthy_test)

# Print out train and test MSE and R2 values
print('LSTM Results - Healthy Donor')
print('Train MSE: %.4f'
      % mean_squared_error(y_healthy_train, y_healthy_train_pred))
print('Train R2: %.4f'
      % r2_score(y_healthy_train, y_healthy_train_pred))
print('Test MSE: %.4f'
      % mean_squared_error(y_healthy_test, y_healthy_test_pred))
print('Test R2: %.4f'
      % r2_score(y_healthy_test, y_healthy_test_pred))
LSTM Results - Healthy Donor
Train MSE: 0.6263
Train R2: 0.3340
Test MSE: 0.5365
Test R2: 0.3969
In [ ]:
# Instantiate an LSTM model for the UC donor
lstm_uc = get_lstm('UC')

# Train the model on the UC donor data
uc_history = lstm_uc.fit(X_uc_train, y_uc_train, epochs=200, verbose=0)
In [ ]:
# Predict on train and test
y_uc_train_pred = lstm_uc.predict(X_uc_train)
y_uc_test_pred = lstm_uc.predict(X_uc_test)

# Print out train and test MSE and R2 values
print('LSTM Results - UC Donor')
print('Train MSE: %.4f'
      % mean_squared_error(y_uc_train, y_uc_train_pred))
print('Train R2: %.4f'
      % r2_score(y_uc_train, y_uc_train_pred))
print('Test MSE: %.4f'
      % mean_squared_error(y_uc_test, y_uc_test_pred))
print('Test R2: %.4f'
      % r2_score(y_uc_test, y_uc_test_pred))
LSTM Results - UC Donor
Train MSE: 0.7371
Train R2: 0.2255
Test MSE: 1.2485
Test R2: 0.1484
In [ ]:
# Plot the training MSE for both LSTMs
fig, ax = plt.subplots(1, 2, figsize=(10,4))
epochs = len(healthy_history.history['loss'])
ax[0].plot(range(epochs), healthy_history.history['loss'], label='Training MSE')
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Loss')
ax[0].set_yscale('log')
ax[0].set_title('Training Loss for LSTM - Healthy Donor')
ax[1].plot(range(epochs), uc_history.history['loss'], label='Training MSE')
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Loss')
ax[1].set_yscale('log')
ax[1].set_title('Training Loss for LSTM - UC Donor')
plt.tight_layout()

As we can see from the outputs above, the autoregressive model and the LSTM perform very similarly on both the healthy donor and the UC donor qPCR measurements. In fact, the autoregressive model slighty outperforms the LSTM on the test sets of each, with lower MSE values and higher coefficients of determination ($R^2$ value). Notably, both models perform significantly better on the healthy donor data, indicating that it is easier to predict the next qPCR measurement for mice that received a sample from a healthy human donor. This might be because the microbiomes of those mice are less variable, whether in response to different perturbations or in general, than those of mice who received a sample from a human infected with UC.

4.2. Predicting Healthy vs. Ulceritive Colitis Donor

Motivated by our observation that qPCR measurements are easier to predict for mice in the healthy donor group, we decided to pursue a secondary classification task modeling donor health status. In particular, we expect that certain features of the relative abundances of taxa in the mice microbiomes will allow us to distinguish between mice in the healthy donor group versus those in the UC donor group. As described above, the nine mice in the experiment were divided into two cohorts: some received a fecal transplant from a healthy human donor, while others received a transplant from a human donor with UC.

Our next set of models attempt to predict, based on relative abundances of taxa in a microbiome along with perturbations, whether a donor was healthy or had UC. This can be thought of as a sequence classification problem, since we are taking a sequence of data (e.g. the series of measurements for each mouse) and predicting a binary classification (e.g. whether that mouse had a healthy or UC donor).

The first step in building this model is to prepare the data such that we have a time series for each mouse of their non-negligible ASV measurements at that point, plus record of their perturbations and donor status.

In [ ]:
# Goal: Format data such that we have full history of each mouse's
# relative abundances along with binary indicator of donor healthy
# and perturbations

# Combine counts with mouse and observation metadata
counts_transposed = counts_cleaned.transpose().reset_index()
all_obs = counts_transposed.merge(metadata, left_on = 'index', right_on = 'sampleID')

# Use SQL-like merge to pull in perturbation timing
sqlcode = '''
select *
from all_obs a
left join perturbations p on a.subject = p.subject and a.time between p.start and p.end
'''
combined = ps.sqldf(sqlcode, locals())

# One hot encode perturbation

# Get one hot encoding of columns B
one_hot = pd.get_dummies(combined.name, prefix = 'perturbation')
# Drop column name as it is now encoded
combined = combined.drop('name',axis = 1)
# Join the encoded df
full_dat = combined.join(one_hot)

# Remove dupe columns
full_dat = full_dat.loc[:,~full_dat.columns.duplicated()]

# Assign donor indicator for mice with healthy vs. UC donor
full_dat['donor'] = np.where(full_dat.subject.isin([6,7,8,9,10]), 1, 0)

4.2.1 Vanilla CNN (Non-Sequential)

First, we fit a Vanilla CNN without taking into account the sequential nature of the data. In this scenario, we treat each measurement for each mouse as its own individual observation. This allows for a greater sample size (despite the small number of mice in the experiment) for training and testing.

To ensure we can evaluate model performance based on both a healthy and a UC mouse, we generate a test set consisting of all the measurements from two mice: one healthy, and one with a UC donor (mice 2 and 6), and use the remaining mice for the training set. For this model, our predictors will consist of all of the ASV measurements, along with a binary indicator at each measurement time point denoting whether that perturbation was being applied at that point in time. The outcome variable is a binary indicator denoting whether a mouse had a healthy or UC donor.

Of the 686 total measurements across all 9 mice, 532 are in our training dataset of 7 mice and 154 are in our test set of 2 mice.

In [ ]:
# Split into test train:
# Since sample size is small try one healthy/one ill mouse in test
print(f'full_dat shape: {full_dat.shape}')

test_dat = full_dat.loc[full_dat['subject'].isin([2,6])]
train_dat= full_dat.loc[full_dat['subject'].isin([3,4,5,7,8,9,10])]
print(f'train_dat shape: {train_dat.shape}')
print(f'test_dat shape: {test_dat.shape}')
 
# Split into predictors and targets for training and evaluation
target_train = train_dat['donor']
input_train = train_dat[[c for c in train_dat.columns if 'ASV' in c or 'perturbation' in c]] # only use ASV identifiers

target_test = test_dat['donor']
input_test = test_dat[[c for c in test_dat.columns if 'ASV' in c or 'perturbation' in c]] # only use ASV identifiers

# Create variable indicating number of features
num_features = input_train.shape[1]
full_dat shape: (686, 896)
train_dat shape: (532, 896)
test_dat shape: (154, 896)
In [ ]:
# Instantiate model and print summary
donor_model_1 = keras.Sequential(
    [
        Input(shape=(num_features)),
        layers.Dense(200, activation="relu", name="layer1"),
        layers.Dense(10, activation="relu", name="layer2"),
        layers.Dense(1, name="layer3"),
    ]
)
donor_model_1.summary()  
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
layer1 (Dense)               (None, 200)               178000    
_________________________________________________________________
layer2 (Dense)               (None, 10)                2010      
_________________________________________________________________
layer3 (Dense)               (None, 1)                 11        
=================================================================
Total params: 180,021
Trainable params: 180,021
Non-trainable params: 0
_________________________________________________________________
In [ ]:
# Compile and train on training mice
donor_model_1.compile(optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])
donor_model_1.fit(input_train, target_train, epochs=100)
Epoch 1/100
17/17 [==============================] - 1s 10ms/step - loss: 1.0568 - accuracy: 0.5526
Epoch 2/100
17/17 [==============================] - 0s 13ms/step - loss: 0.5742 - accuracy: 0.7154
Epoch 3/100
17/17 [==============================] - 0s 9ms/step - loss: 0.5207 - accuracy: 0.7515
Epoch 4/100
17/17 [==============================] - 0s 12ms/step - loss: 0.4693 - accuracy: 0.8074
Epoch 5/100
17/17 [==============================] - 0s 9ms/step - loss: 0.4106 - accuracy: 0.9278
Epoch 6/100
17/17 [==============================] - 0s 7ms/step - loss: 0.3690 - accuracy: 0.9424
Epoch 7/100
17/17 [==============================] - 0s 6ms/step - loss: 0.3231 - accuracy: 0.9577
Epoch 8/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2701 - accuracy: 0.9838
Epoch 9/100
17/17 [==============================] - 0s 9ms/step - loss: 0.2258 - accuracy: 0.9811
Epoch 10/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1669 - accuracy: 0.9818
Epoch 11/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1343 - accuracy: 0.9825
Epoch 12/100
17/17 [==============================] - 0s 8ms/step - loss: 0.1912 - accuracy: 0.9718
Epoch 13/100
17/17 [==============================] - 0s 9ms/step - loss: 0.3720 - accuracy: 0.9525
Epoch 14/100
17/17 [==============================] - 0s 6ms/step - loss: 0.3351 - accuracy: 0.9280
Epoch 15/100
17/17 [==============================] - 0s 5ms/step - loss: 0.1826 - accuracy: 0.9429
Epoch 16/100
17/17 [==============================] - 0s 4ms/step - loss: 0.1057 - accuracy: 0.9649
Epoch 17/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1037 - accuracy: 0.9801
Epoch 18/100
17/17 [==============================] - 0s 8ms/step - loss: 0.0733 - accuracy: 0.9898
Epoch 19/100
17/17 [==============================] - 0s 8ms/step - loss: 0.0866 - accuracy: 0.9769
Epoch 20/100
17/17 [==============================] - 0s 8ms/step - loss: 0.0947 - accuracy: 0.9858
Epoch 21/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1277 - accuracy: 0.9722
Epoch 22/100
17/17 [==============================] - 0s 6ms/step - loss: 0.1436 - accuracy: 0.9849
Epoch 23/100
17/17 [==============================] - 0s 6ms/step - loss: 0.0773 - accuracy: 0.9831
Epoch 24/100
17/17 [==============================] - 0s 24ms/step - loss: 0.0552 - accuracy: 0.9875
Epoch 25/100
17/17 [==============================] - 0s 16ms/step - loss: 0.0931 - accuracy: 0.9857
Epoch 26/100
17/17 [==============================] - 0s 6ms/step - loss: 0.0709 - accuracy: 0.9835
Epoch 27/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0984 - accuracy: 0.9808
Epoch 28/100
17/17 [==============================] - 0s 8ms/step - loss: 0.0669 - accuracy: 0.9773
Epoch 29/100
17/17 [==============================] - 0s 7ms/step - loss: 0.0712 - accuracy: 0.9824
Epoch 30/100
17/17 [==============================] - 0s 7ms/step - loss: 0.0814 - accuracy: 0.9834
Epoch 31/100
17/17 [==============================] - 0s 8ms/step - loss: 0.0641 - accuracy: 0.9879
Epoch 32/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0679 - accuracy: 0.9769
Epoch 33/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0760 - accuracy: 0.9812
Epoch 34/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0819 - accuracy: 0.9829
Epoch 35/100
17/17 [==============================] - 0s 5ms/step - loss: 0.0841 - accuracy: 0.9811
Epoch 36/100
17/17 [==============================] - 0s 5ms/step - loss: 0.0916 - accuracy: 0.9793
Epoch 37/100
17/17 [==============================] - 0s 5ms/step - loss: 0.0789 - accuracy: 0.9777
Epoch 38/100
17/17 [==============================] - 0s 8ms/step - loss: 0.0590 - accuracy: 0.9841
Epoch 39/100
17/17 [==============================] - 0s 8ms/step - loss: 0.0765 - accuracy: 0.9847
Epoch 40/100
17/17 [==============================] - 0s 8ms/step - loss: 0.0838 - accuracy: 0.9787
Epoch 41/100
17/17 [==============================] - 0s 15ms/step - loss: 0.1461 - accuracy: 0.9770
Epoch 42/100
17/17 [==============================] - 0s 7ms/step - loss: 0.0681 - accuracy: 0.9744
Epoch 43/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0742 - accuracy: 0.9832
Epoch 44/100
17/17 [==============================] - 0s 7ms/step - loss: 0.1556 - accuracy: 0.9690
Epoch 45/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0682 - accuracy: 0.9739
Epoch 46/100
17/17 [==============================] - 0s 7ms/step - loss: 0.0992 - accuracy: 0.9757
Epoch 47/100
17/17 [==============================] - 0s 4ms/step - loss: 0.0659 - accuracy: 0.9824
Epoch 48/100
17/17 [==============================] - 0s 6ms/step - loss: 0.0742 - accuracy: 0.9818
Epoch 49/100
17/17 [==============================] - 0s 4ms/step - loss: 0.0445 - accuracy: 0.9887
Epoch 50/100
17/17 [==============================] - 0s 6ms/step - loss: 0.0761 - accuracy: 0.9807
Epoch 51/100
17/17 [==============================] - 0s 5ms/step - loss: 0.0370 - accuracy: 0.9898
Epoch 52/100
17/17 [==============================] - 0s 5ms/step - loss: 0.0835 - accuracy: 0.9791
Epoch 53/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0643 - accuracy: 0.9838
Epoch 54/100
17/17 [==============================] - 0s 8ms/step - loss: 0.0546 - accuracy: 0.9833
Epoch 55/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0663 - accuracy: 0.9891
Epoch 56/100
17/17 [==============================] - 0s 8ms/step - loss: 0.0665 - accuracy: 0.9802
Epoch 57/100
17/17 [==============================] - 0s 6ms/step - loss: 0.0656 - accuracy: 0.9813
Epoch 58/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0532 - accuracy: 0.9895
Epoch 59/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0532 - accuracy: 0.9799
Epoch 60/100
17/17 [==============================] - 0s 8ms/step - loss: 0.1127 - accuracy: 0.9742
Epoch 61/100
17/17 [==============================] - 0s 15ms/step - loss: 0.1106 - accuracy: 0.9804
Epoch 62/100
17/17 [==============================] - 0s 11ms/step - loss: 0.0813 - accuracy: 0.9838
Epoch 63/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0300 - accuracy: 0.9897
Epoch 64/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0665 - accuracy: 0.9821
Epoch 65/100
17/17 [==============================] - 0s 9ms/step - loss: 0.1041 - accuracy: 0.9841
Epoch 66/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0449 - accuracy: 0.9841
Epoch 67/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0804 - accuracy: 0.9866
Epoch 68/100
17/17 [==============================] - 0s 5ms/step - loss: 0.1046 - accuracy: 0.9705
Epoch 69/100
17/17 [==============================] - 0s 11ms/step - loss: 0.1176 - accuracy: 0.9876
Epoch 70/100
17/17 [==============================] - 0s 10ms/step - loss: 0.0650 - accuracy: 0.9790
Epoch 71/100
17/17 [==============================] - 0s 4ms/step - loss: 0.0615 - accuracy: 0.9845
Epoch 72/100
17/17 [==============================] - 0s 4ms/step - loss: 0.0356 - accuracy: 0.9892
Epoch 73/100
17/17 [==============================] - 0s 4ms/step - loss: 0.0525 - accuracy: 0.9867
Epoch 74/100
17/17 [==============================] - 0s 4ms/step - loss: 0.0563 - accuracy: 0.9881
Epoch 75/100
17/17 [==============================] - 0s 12ms/step - loss: 0.1009 - accuracy: 0.9857
Epoch 76/100
17/17 [==============================] - 0s 7ms/step - loss: 0.0745 - accuracy: 0.9777
Epoch 77/100
17/17 [==============================] - 0s 6ms/step - loss: 0.0661 - accuracy: 0.9918
Epoch 78/100
17/17 [==============================] - 0s 5ms/step - loss: 0.0975 - accuracy: 0.9866
Epoch 79/100
17/17 [==============================] - 0s 4ms/step - loss: 0.0614 - accuracy: 0.9855
Epoch 80/100
17/17 [==============================] - 0s 7ms/step - loss: 0.1221 - accuracy: 0.9839
Epoch 81/100
17/17 [==============================] - 0s 14ms/step - loss: 0.1260 - accuracy: 0.9813
Epoch 82/100
17/17 [==============================] - 0s 6ms/step - loss: 0.0794 - accuracy: 0.9859
Epoch 83/100
17/17 [==============================] - 0s 15ms/step - loss: 0.0869 - accuracy: 0.9808
Epoch 84/100
17/17 [==============================] - 0s 7ms/step - loss: 0.0592 - accuracy: 0.9785
Epoch 85/100
17/17 [==============================] - 0s 5ms/step - loss: 0.0694 - accuracy: 0.9800
Epoch 86/100
17/17 [==============================] - 0s 5ms/step - loss: 0.0580 - accuracy: 0.9898
Epoch 87/100
17/17 [==============================] - 0s 8ms/step - loss: 0.1438 - accuracy: 0.9776
Epoch 88/100
17/17 [==============================] - 0s 7ms/step - loss: 0.0360 - accuracy: 0.9934
Epoch 89/100
17/17 [==============================] - 0s 8ms/step - loss: 0.0532 - accuracy: 0.9853
Epoch 90/100
17/17 [==============================] - 0s 6ms/step - loss: 0.0789 - accuracy: 0.9829
Epoch 91/100
17/17 [==============================] - 0s 5ms/step - loss: 0.0388 - accuracy: 0.9878
Epoch 92/100
17/17 [==============================] - 0s 4ms/step - loss: 0.0351 - accuracy: 0.9920
Epoch 93/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0737 - accuracy: 0.9824
Epoch 94/100
17/17 [==============================] - 0s 6ms/step - loss: 0.0392 - accuracy: 0.9887
Epoch 95/100
17/17 [==============================] - 0s 6ms/step - loss: 0.0585 - accuracy: 0.9813
Epoch 96/100
17/17 [==============================] - 0s 12ms/step - loss: 0.0737 - accuracy: 0.9826
Epoch 97/100
17/17 [==============================] - 0s 21ms/step - loss: 0.0706 - accuracy: 0.9802
Epoch 98/100
17/17 [==============================] - 0s 6ms/step - loss: 0.0372 - accuracy: 0.9891
Epoch 99/100
17/17 [==============================] - 0s 9ms/step - loss: 0.0328 - accuracy: 0.9884
Epoch 100/100
17/17 [==============================] - 0s 13ms/step - loss: 0.1277 - accuracy: 0.9791
Out[ ]:
<tensorflow.python.keras.callbacks.History at 0x7f3d017b0d50>
In [ ]:
# Print out model performance
preds_train = donor_model_1.evaluate(input_train, target_train)
print("Train loss, train acc:", preds_train)

preds_test = donor_model_1.evaluate(input_test, target_test)
print("Test loss, test acc:", preds_test)

print("train %s: %.2f%%" % (donor_model_1.metrics_names[1], preds_train[1]*100))
print("test %s: %.2f%%" % (donor_model_1.metrics_names[1], preds_test[1]*100))
17/17 [==============================] - 0s 8ms/step - loss: 0.0626 - accuracy: 0.9850
Train loss, train acc: [0.06257487088441849, 0.9849624037742615]
5/5 [==============================] - 0s 9ms/step - loss: 0.0435 - accuracy: 0.9805
Test loss, test acc: [0.043476205319166183, 0.9805194735527039]
train accuracy: 98.50%
test accuracy: 98.05%

As can be seen in the output above, this simple, non-sequential model performs quite well, yielding a high train and test accuracy. This suggests that with a high degree of accuracy, we can predict based on an individual measurement of relative abundances and presence of perturbations, whether a mouse had a healthy donor or one with UC.

As noted earlier in this code report, an important nuance of this dataset is that it is a time series. Each measurement of ASV relative abundances in mice does not exist in isolation, but rather is related to those measurements that came before and after. To account for this data dynamic, we will next explore how this modeling task can be accomplished using an LSTM.

4.2.2 LSTM (Sequential Model) with Bootstrap

LSTMs are suited for time series task because they utilize past observations as part of the training process. Each mouse has a unique history of measurements, so in our next modeling task we will incorporate the entire mouse history into training.

In this model, each mouse will be its own unit of observation (similar to how we would treat sentences in language modeling), with its sequence of ASV measurements comprising its history.

It's worth noting that we have a very small sample size. With only nine mice total (four with healthy donors and five with UC donors), this presents a challenge insofar as our model will not have many mice to train on, and will also have few mice on which to evaluate its performance.

With this limitation in mind, we decided to take a bootstrap approach to assess model training and performance with more variation. With each bootstrap iteration, we randomly select two mice (one from the healthy cohort and one from the UC cohort) to use as our test group, and use the others for the training group. We then train an LSTM using the mice's entire history of relative abundance measurements and perturbation indicators at each timestep, with a goal of predicting its donor status. Then, we evaluate the overall performance across all bootstrapped iterations.

In [ ]:
# Shape of data for LSTM should be 
# (number of mice x number of observations x number of features)

# Use min observations as number of observations
min_observations = full_dat.groupby('subject').count().min()[0]

healthy_mice = [2,3,4,5]
uc_mice = [6,7,8,9,10]
all_mice = set(healthy_mice+uc_mice)
In [ ]:
# Define a helper function that creates and compiles a LSTM
def make_lstm():
    model_lstm_boot = keras.Sequential(
        [
        Input(shape=(74, 889)),
        layers.LSTM(50, activation = 'relu'),
        layers.Dense(200, activation="relu", name="layer1"),
        layers.Dense(64, activation="relu", name="layer1.5"),
        layers.Dense(10, activation="relu", name="layer2"),
        layers.Dense(1, name="layer3"),
        ]
    )
    model_lstm_boot.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model_lstm_boot
In [ ]:
# Check architecture
test_lstm = make_lstm()
test_lstm.summary()
Model: "sequential_23"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm_23 (LSTM)               (None, 50)                188000    
_________________________________________________________________
layer1 (Dense)               (None, 200)               10200     
_________________________________________________________________
layer1.5 (Dense)             (None, 64)                12864     
_________________________________________________________________
layer2 (Dense)               (None, 10)                650       
_________________________________________________________________
layer3 (Dense)               (None, 1)                 11        
=================================================================
Total params: 211,725
Trainable params: 211,725
Non-trainable params: 0
_________________________________________________________________
In [ ]:
# Perform bootstrap
# Takes a while but should run with greater n (maybe 50)

num_boot = 50
train_acc = []
test_acc = []

# In each boot randomly select 1 healthy and one unhealthy mouse for test set
# Use rest for training
for boot in range(num_boot):
    test_mice = []
    for group in [healthy_mice, uc_mice]:
        test_mouse = random.choice(group)
        test_mice.append(test_mouse)
    #print(test_mice)
    test_mice_set = set(test_mice)
    train_mice = [x for x in all_mice if x not in test_mice_set]
    #print(train_mice)

    # Set up train and test data
    test_dat_boot = full_dat.loc[full_dat['subject'].isin(test_mice)]
    train_dat_boot= full_dat.loc[full_dat['subject'].isin(train_mice)]

    x_train_boot = []
    y_train_boot = []
    x_test_boot = []
    y_test_boot = []
    for data in [train_dat_boot, test_dat_boot]:
        for mouse_number in data['subject'].unique():
            mouse_data = data.loc[data['subject'] == mouse_number].sort_values('time')

            # filter out columns
            if data.equals(train_dat_boot):
                y_train_boot.append(mouse_data['donor'].iloc[0])
            else:
                y_test_boot.append(mouse_data['donor'].iloc[0])
            filtered_mouse_data = mouse_data[[c for c in full_dat.columns if 'ASV' in c or 'perturbation' in c]]
            #print(filtered_mouse_data.shape)
        
            padded_mouse_data = filtered_mouse_data[:min_observations] #padded_mouse_data[0:min_observations]
            #print(padded_mouse_data.shape)
            if data.equals(train_dat_boot):
                x_train_boot.append(padded_mouse_data)
            else:
                x_test_boot.append(padded_mouse_data)
    x_train_boot_np = np.array(x_train_boot)
    y_train_boot_np = np.array(y_train_boot)

    x_test_boot_np = np.array(x_test_boot)
    y_test_boot_np = np.array(y_test_boot)

    #make_and_fit_lstm(x_train_boot_np, y_train_boot_np)
    model_lstm_boot = make_lstm()
    model_lstm_boot.fit(x_train_boot_np, y_train_boot_np, epochs=100, verbose = False)

    preds_train_lstm_boot = model_lstm_boot.evaluate(x_train_boot_np, y_train_boot_np)
    print("Train loss, train acc:", preds_train_lstm_boot)

    preds_test_lstm_boot = model_lstm_boot.evaluate(x_test_boot_np, y_test_boot_np)
    print("Test loss, test acc:", preds_test_lstm_boot)

    train_acc.append(preds_train_lstm_boot[1])
    test_acc.append(preds_test_lstm_boot[1])
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cfa6ece60> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 209ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 30ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cfa2027a0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 186ms/step - loss: 0.0033 - accuracy: 1.0000
Train loss, train acc: [0.0033208136446774006, 1.0]
1/1 [==============================] - 0s 23ms/step - loss: 0.0039 - accuracy: 1.0000
Test loss, test acc: [0.003874282818287611, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf9040440> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 186ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 29ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf8c32200> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 196ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 18ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf59c5b90> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 198ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 18ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cef3b4830> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 188ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 19ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cedb974d0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 238ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 19ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cec2ffef0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 192ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 19ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cfa8bf830> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 231ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 18ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cfaf6a710> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 187ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 17ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf7f96200> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 185ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 18ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cffcafd40> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 185ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 18ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf4151f80> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 187ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 18ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf904bb90> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 193ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 19ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf8f17830> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 191ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 23ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf9f18200> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 190ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 18ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf9079ef0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 195ms/step - loss: 0.0034 - accuracy: 1.0000
Train loss, train acc: [0.003417105646803975, 1.0]
1/1 [==============================] - 0s 18ms/step - loss: 0.0040 - accuracy: 1.0000
Test loss, test acc: [0.003986623138189316, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cfa694c20> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 190ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 23ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf9ea5710> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 200ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 18ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf4dd4290> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 190ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 18ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cec47dcb0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 196ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 18ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cedb7cdd0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 190ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 19ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf7d9c0e0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 191ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 18ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf8f7f830> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 194ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 18ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cfe825830> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 186ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 17ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cfa47c4d0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 194ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 17ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf8a2e4d0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 201ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 18ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cfa0b84d0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 194ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 17ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf8e72050> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 185ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 18ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cedd68200> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 221ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 18ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf9fefef0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 201ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 18ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cedbc9c20> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 243ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 27ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf8c85dd0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 189ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 19ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cfa189a70> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 193ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 18ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cec3ac710> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 188ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 19ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3ce5554200> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 191ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 18ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cffeff950> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 185ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 18ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cedb7cb00> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 184ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 17ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3d03d43cb0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 192ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 18ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cedb78cb0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 196ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 18ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cee7f7cb0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 191ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 20ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cec3f5e60> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 193ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 18ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cfa4c9b00> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 208ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 18ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3d00d31cb0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 197ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 18ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf8cbd950> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 223ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 28ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cfa0735f0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 183ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 25ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cef39a7a0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 181ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 20ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3cf8ac6440> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 194ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Train loss, train acc: [0.0, 1.0]
1/1 [==============================] - 0s 19ms/step - loss: 0.0000e+00 - accuracy: 1.0000
Test loss, test acc: [0.0, 1.0]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3ceac2f8c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 193ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 18ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3d023ffdd0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 196ms/step - loss: 8.8143 - accuracy: 0.4286
Train loss, train acc: [8.81425666809082, 0.4285714328289032]
1/1 [==============================] - 0s 17ms/step - loss: 7.7125 - accuracy: 0.5000
Test loss, test acc: [7.712474346160889, 0.5]
In [ ]:
# Quick histogram
# Since there are only two mice in test group, it
# Can either get 0, 50% or 100% accuracy.
plt.hist(test_acc)  # bins=30
plt.ylabel('Count')
plt.xlabel('Accuracy')
plt.title('Accuracy Scores Across Bootstraps')
Out[ ]:
Text(0.5, 1.0, 'Accuracy Scores Across Bootstraps')
In [ ]:
# Print out average performance
print(f'Avg. Train Accuracy Across Bootstraps: {round(np.mean(train_acc),4)}')
print(f'Avg. Test Accuracy Across Bootstraps: {round(np.mean(test_acc),4)}')
Avg. Train Accuracy Across Bootstraps: 0.68
Avg. Test Accuracy Across Bootstraps: 0.72

As can be seen by the output above, the LSTM model trained on each mouse's entire history does not perform as well as the previous vanilla non-sequential model, with a lower average test accuracy score across bootstraps. However, the average test accuracy across bootstraps is better than what we would expect via random choice (50%), suggesting that the model is able to learn and predict to some level of reliability despite the tiny sample size. This is also visible from the histogram of bootstrapped test accuracy scores. Because there are only two mice in each bootstrapped test group, the accuracy scores can be either 0%, 50%, or 100%. There are nearly as many 100% bootstrapped test accuracy scores as 50%, which indicates positive test performance in many of the bootstrapped iterations.

It would be interesting to try this sequential method in scenarios with many more mice to evaluate its relative performance to the non-sequential model.

4.2.3 Multivariate LSTM (Sequential Model) using Phylum Relative Abundance

Similar to the previous LSTM model, each mouse represents a sample which means that we are dealing with a very small total sample size of 9 with a held out test set of 2 mice: 1 with a healthy donor, and 1 with a UC donor.

The multivariate LSTM model is trained on time-series data with each time point containing multiple features. These features correspond to the relative abundance of every phylum. By aggregating the data by phylum, we hope to realize general patterns in the time-series data that are indicitive of donor status (microbiome derived from a patient with UC or a healthy donor.)

In [ ]:
# Group the counts dataframe by class and compute the relative abundance per phylum
phylum_abundance = counts.assign(Class = asv_and_taxonomy['Phylum'].to_numpy()).groupby('Class').sum()
phylum_abundance = phylum_abundance.transpose()

# Normalize every row by the row sum to create relative abundance values
rowsums = phylum_abundance.sum(axis=1)
phylum_abundance = phylum_abundance.div(rowsums, axis='index')

# Specify the mouse ID for samples in the metadata
donor_data = metadata.copy()
donor_data['mouseID'] = [x.split('-')[0] for x in donor_data['sampleID']]
# Group the data by mouse ID and order the data in chronological order
donor_data = donor_data.sort_values(['mouseID', 'time'], axis=0, ascending=True)

# Sort the phylum abundance data based on the sample ID order found in the donor data
sorted_indices = [np.where(phylum_abundance.index == element)[0][0] for element in donor_data['sampleID']]
phylum_abundance = phylum_abundance.iloc[sorted_indices,:]
In [ ]:
# Define the response (the donor status)
donor_data['donor_status'] = np.where(donor_data['subject'].isin([6, 7, 8, 9, 10]), 1, 0)

# Specify the training and test sets (the test set consists of one healthy mouse 
train = phylum_abundance.loc[donor_data['mouseID'].astype(int).isin([3,4,5,7,8,9,10]).to_numpy(),:].to_numpy()
test = phylum_abundance.loc[donor_data['mouseID'].astype(int).isin([2,6]).to_numpy(),:].to_numpy()
donor_train = donor_data.loc[donor_data['mouseID'].astype(int).isin([3,4,5,7,8,9,10]).to_numpy(),:]
donor_test = donor_data.loc[donor_data['mouseID'].astype(int).isin([2,6]).to_numpy(),:]

print(f"The shape of X train and X test are {train.shape} and {test.shape} respectively.")
The shape of X train and X test are (532, 11) and (154, 11) respectively.
In [ ]:
# Define the predictors and response for the training set
X_train,y_train = [],[]
for mouse in donor_train['mouseID'].unique():
    # Pull the data associated with the given mouse
    training_data = train[donor_train['mouseID'] == mouse,:]
    training_response = donor_train.loc[donor_train['mouseID'] == mouse,:]['donor_status'].unique()[0]
    
    X_train.append(training_data)
    y_train.append(training_response)

# Define the predictors and response for the test set
X_test,y_test = [],[]
for mouse in donor_test['mouseID'].unique():
    # Pull the data associated with the given mouse
    test_data = test[donor_test['mouseID'] == mouse,:]
    test_response = donor_test.loc[donor_test['mouseID'] == mouse,:]['donor_status'].unique()[0]
    
    X_test.append(test_data)
    y_test.append(test_response)
In [ ]:
# Pad the the training and test set given that the mice contributed varying numbers of samples

# Determine the max number of samples contributed
max_samples = np.max([x.shape[0] for x in X_train])

# Pad the training set
for i in range(len(X_train)):
    data = X_train[i]
    # Define the padding
    padding = np.zeros((max_samples-data.shape[0], data.shape[1]))
    # Add the padding
    data = np.append(data, padding, axis=0)
    X_train[i] = data.copy()

# Pad the test set
for i in range(len(X_test)):
    data = X_test[i]
    # Define the padding
    padding = np.zeros((max_samples-data.shape[0], data.shape[1]))
    # Add the padding
    data = np.append(data, padding, axis=0)
    X_test[i] = data.copy()

X_train = np.array(X_train)
X_test = np.array(X_test)
y_train = np.array(y_train)
y_test = np.array(y_test)
In [ ]:
# Define a function used to create the LSTM model
def create_LSTM(input_size, hidden_states):
    model = Sequential()
    model.add(Input(shape=input_size))
    model.add(LSTM(hidden_states, activation = 'relu'))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(128, activation='relu'))
    model.add(Dense(64, activation='relu'))
    model.add(Dense(32, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    return model
In [ ]:
# Define the model
model = create_LSTM(X_train[0].shape, 64)

# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# Train the model
history = model.fit(X_train, y_train, epochs=100)
Epoch 1/100
1/1 [==============================] - 1s 1s/step - loss: 0.6926 - accuracy: 0.5714
Epoch 2/100
1/1 [==============================] - 0s 32ms/step - loss: 0.6917 - accuracy: 0.5714
Epoch 3/100
1/1 [==============================] - 0s 35ms/step - loss: 0.6909 - accuracy: 0.5714
Epoch 4/100
1/1 [==============================] - 0s 26ms/step - loss: 0.6901 - accuracy: 0.5714
Epoch 5/100
1/1 [==============================] - 0s 24ms/step - loss: 0.6891 - accuracy: 0.5714
Epoch 6/100
1/1 [==============================] - 0s 32ms/step - loss: 0.6880 - accuracy: 0.5714
Epoch 7/100
1/1 [==============================] - 0s 26ms/step - loss: 0.6869 - accuracy: 0.5714
Epoch 8/100
1/1 [==============================] - 0s 26ms/step - loss: 0.6856 - accuracy: 0.5714
Epoch 9/100
1/1 [==============================] - 0s 45ms/step - loss: 0.6843 - accuracy: 0.5714
Epoch 10/100
1/1 [==============================] - 0s 26ms/step - loss: 0.6829 - accuracy: 0.5714
Epoch 11/100
1/1 [==============================] - 0s 27ms/step - loss: 0.6815 - accuracy: 0.5714
Epoch 12/100
1/1 [==============================] - 0s 49ms/step - loss: 0.6801 - accuracy: 0.5714
Epoch 13/100
1/1 [==============================] - 0s 25ms/step - loss: 0.6788 - accuracy: 0.5714
Epoch 14/100
1/1 [==============================] - 0s 25ms/step - loss: 0.6776 - accuracy: 0.5714
Epoch 15/100
1/1 [==============================] - 0s 38ms/step - loss: 0.6767 - accuracy: 0.5714
Epoch 16/100
1/1 [==============================] - 0s 28ms/step - loss: 0.6754 - accuracy: 0.5714
Epoch 17/100
1/1 [==============================] - 0s 45ms/step - loss: 0.6740 - accuracy: 0.5714
Epoch 18/100
1/1 [==============================] - 0s 29ms/step - loss: 0.6723 - accuracy: 0.5714
Epoch 19/100
1/1 [==============================] - 0s 29ms/step - loss: 0.6701 - accuracy: 0.5714
Epoch 20/100
1/1 [==============================] - 0s 45ms/step - loss: 0.6675 - accuracy: 0.5714
Epoch 21/100
1/1 [==============================] - 0s 28ms/step - loss: 0.6641 - accuracy: 0.5714
Epoch 22/100
1/1 [==============================] - 0s 34ms/step - loss: 0.6604 - accuracy: 0.5714
Epoch 23/100
1/1 [==============================] - 0s 44ms/step - loss: 0.6558 - accuracy: 0.5714
Epoch 24/100
1/1 [==============================] - 0s 29ms/step - loss: 0.6499 - accuracy: 0.5714
Epoch 25/100
1/1 [==============================] - 0s 43ms/step - loss: 0.6419 - accuracy: 0.5714
Epoch 26/100
1/1 [==============================] - 0s 27ms/step - loss: 0.6321 - accuracy: 0.5714
Epoch 27/100
1/1 [==============================] - 0s 36ms/step - loss: 0.6194 - accuracy: 0.5714
Epoch 28/100
1/1 [==============================] - 0s 45ms/step - loss: 0.6019 - accuracy: 0.5714
Epoch 29/100
1/1 [==============================] - 0s 32ms/step - loss: 0.5778 - accuracy: 0.5714
Epoch 30/100
1/1 [==============================] - 0s 49ms/step - loss: 0.5464 - accuracy: 0.5714
Epoch 31/100
1/1 [==============================] - 0s 57ms/step - loss: 0.5024 - accuracy: 0.7143
Epoch 32/100
1/1 [==============================] - 0s 34ms/step - loss: 0.4440 - accuracy: 0.7143
Epoch 33/100
1/1 [==============================] - 0s 28ms/step - loss: 0.3643 - accuracy: 0.7143
Epoch 34/100
1/1 [==============================] - 0s 27ms/step - loss: 0.2850 - accuracy: 1.0000
Epoch 35/100
1/1 [==============================] - 0s 46ms/step - loss: 0.2633 - accuracy: 1.0000
Epoch 36/100
1/1 [==============================] - 0s 29ms/step - loss: 0.2503 - accuracy: 1.0000
Epoch 37/100
1/1 [==============================] - 0s 28ms/step - loss: 0.2379 - accuracy: 1.0000
Epoch 38/100
1/1 [==============================] - 0s 40ms/step - loss: 0.2253 - accuracy: 1.0000
Epoch 39/100
1/1 [==============================] - 0s 26ms/step - loss: 0.2115 - accuracy: 1.0000
Epoch 40/100
1/1 [==============================] - 0s 27ms/step - loss: 0.1956 - accuracy: 1.0000
Epoch 41/100
1/1 [==============================] - 0s 51ms/step - loss: 0.1771 - accuracy: 1.0000
Epoch 42/100
1/1 [==============================] - 0s 28ms/step - loss: 0.1555 - accuracy: 1.0000
Epoch 43/100
1/1 [==============================] - 0s 48ms/step - loss: 0.1307 - accuracy: 1.0000
Epoch 44/100
1/1 [==============================] - 0s 28ms/step - loss: 0.1022 - accuracy: 1.0000
Epoch 45/100
1/1 [==============================] - 0s 28ms/step - loss: 0.0702 - accuracy: 1.0000
Epoch 46/100
1/1 [==============================] - 0s 47ms/step - loss: 0.0361 - accuracy: 1.0000
Epoch 47/100
1/1 [==============================] - 0s 31ms/step - loss: 2.0836 - accuracy: 0.5714
Epoch 48/100
1/1 [==============================] - 0s 38ms/step - loss: 1.8022 - accuracy: 0.4286
Epoch 49/100
1/1 [==============================] - 0s 31ms/step - loss: 1.6913 - accuracy: 0.4286
Epoch 50/100
1/1 [==============================] - 0s 27ms/step - loss: 1.5617 - accuracy: 0.4286
Epoch 51/100
1/1 [==============================] - 0s 47ms/step - loss: 1.4092 - accuracy: 0.4286
Epoch 52/100
1/1 [==============================] - 0s 27ms/step - loss: 1.2577 - accuracy: 0.4286
Epoch 53/100
1/1 [==============================] - 0s 30ms/step - loss: 1.1201 - accuracy: 0.4286
Epoch 54/100
1/1 [==============================] - 0s 47ms/step - loss: 1.0067 - accuracy: 0.4286
Epoch 55/100
1/1 [==============================] - 0s 27ms/step - loss: 0.9197 - accuracy: 0.4286
Epoch 56/100
1/1 [==============================] - 0s 47ms/step - loss: 0.8559 - accuracy: 0.4286
Epoch 57/100
1/1 [==============================] - 0s 28ms/step - loss: 0.8097 - accuracy: 0.4286
Epoch 58/100
1/1 [==============================] - 0s 27ms/step - loss: 0.7760 - accuracy: 0.4286
Epoch 59/100
1/1 [==============================] - 0s 42ms/step - loss: 0.7517 - accuracy: 0.4286
Epoch 60/100
1/1 [==============================] - 0s 28ms/step - loss: 0.7343 - accuracy: 0.4286
Epoch 61/100
1/1 [==============================] - 0s 30ms/step - loss: 0.7212 - accuracy: 0.4286
Epoch 62/100
1/1 [==============================] - 0s 43ms/step - loss: 0.7111 - accuracy: 0.4286
Epoch 63/100
1/1 [==============================] - 0s 28ms/step - loss: 0.7031 - accuracy: 0.4286
Epoch 64/100
1/1 [==============================] - 0s 28ms/step - loss: 0.6965 - accuracy: 0.4286
Epoch 65/100
1/1 [==============================] - 0s 57ms/step - loss: 0.6913 - accuracy: 0.4286
Epoch 66/100
1/1 [==============================] - 0s 29ms/step - loss: 0.6868 - accuracy: 0.4286
Epoch 67/100
1/1 [==============================] - 0s 27ms/step - loss: 0.6827 - accuracy: 0.4286
Epoch 68/100
1/1 [==============================] - 0s 31ms/step - loss: 0.6783 - accuracy: 0.4286
Epoch 69/100
1/1 [==============================] - 0s 26ms/step - loss: 0.6731 - accuracy: 0.4286
Epoch 70/100
1/1 [==============================] - 0s 46ms/step - loss: 0.6660 - accuracy: 0.4286
Epoch 71/100
1/1 [==============================] - 0s 29ms/step - loss: 0.6545 - accuracy: 0.4286
Epoch 72/100
1/1 [==============================] - 0s 26ms/step - loss: 0.6316 - accuracy: 0.4286
Epoch 73/100
1/1 [==============================] - 0s 46ms/step - loss: 0.5718 - accuracy: 0.4286
Epoch 74/100
1/1 [==============================] - 0s 28ms/step - loss: 0.4538 - accuracy: 0.4286
Epoch 75/100
1/1 [==============================] - 0s 32ms/step - loss: 0.4542 - accuracy: 0.4286
Epoch 76/100
1/1 [==============================] - 0s 40ms/step - loss: 0.4588 - accuracy: 0.4286
Epoch 77/100
1/1 [==============================] - 0s 28ms/step - loss: 0.4645 - accuracy: 0.4286
Epoch 78/100
1/1 [==============================] - 0s 48ms/step - loss: 0.4668 - accuracy: 0.4286
Epoch 79/100
1/1 [==============================] - 0s 28ms/step - loss: 0.4586 - accuracy: 0.4286
Epoch 80/100
1/1 [==============================] - 0s 28ms/step - loss: 0.4388 - accuracy: 0.4286
Epoch 81/100
1/1 [==============================] - 0s 46ms/step - loss: 0.4175 - accuracy: 0.4286
Epoch 82/100
1/1 [==============================] - 0s 28ms/step - loss: 0.4009 - accuracy: 0.4286
Epoch 83/100
1/1 [==============================] - 0s 29ms/step - loss: 0.3906 - accuracy: 1.0000
Epoch 84/100
1/1 [==============================] - 0s 37ms/step - loss: 0.3839 - accuracy: 1.0000
Epoch 85/100
1/1 [==============================] - 0s 27ms/step - loss: 0.3787 - accuracy: 1.0000
Epoch 86/100
1/1 [==============================] - 0s 47ms/step - loss: 1.3883 - accuracy: 0.5714
Epoch 87/100
1/1 [==============================] - 0s 27ms/step - loss: 0.6760 - accuracy: 0.5714
Epoch 88/100
1/1 [==============================] - 0s 28ms/step - loss: 0.6859 - accuracy: 0.5714
Epoch 89/100
1/1 [==============================] - 0s 39ms/step - loss: 0.6867 - accuracy: 0.5714
Epoch 90/100
1/1 [==============================] - 0s 27ms/step - loss: 0.6868 - accuracy: 0.5714
Epoch 91/100
1/1 [==============================] - 0s 28ms/step - loss: 0.6868 - accuracy: 0.5714
Epoch 92/100
1/1 [==============================] - 0s 51ms/step - loss: 0.6869 - accuracy: 0.5714
Epoch 93/100
1/1 [==============================] - 0s 30ms/step - loss: 0.6869 - accuracy: 0.5714
Epoch 94/100
1/1 [==============================] - 0s 26ms/step - loss: 0.6868 - accuracy: 0.5714
Epoch 95/100
1/1 [==============================] - 0s 54ms/step - loss: 0.6867 - accuracy: 0.5714
Epoch 96/100
1/1 [==============================] - 0s 37ms/step - loss: 0.6866 - accuracy: 0.5714
Epoch 97/100
1/1 [==============================] - 0s 25ms/step - loss: 0.6865 - accuracy: 0.5714
Epoch 98/100
1/1 [==============================] - 0s 30ms/step - loss: 0.6863 - accuracy: 0.5714
Epoch 99/100
1/1 [==============================] - 0s 27ms/step - loss: 0.6861 - accuracy: 0.5714
Epoch 100/100
1/1 [==============================] - 0s 42ms/step - loss: 0.6859 - accuracy: 0.5714
In [ ]:
train_accuracy = model.evaluate(X_train, y_train)[1]
print(f"The training set accuracy is {train_accuracy:0.4f}")

test_accuracy = model.evaluate(X_test, y_test)[1]
print(f"The training set accuracy is {test_accuracy:0.4f}")
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7f3ceda4c7a0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 190ms/step - loss: 0.6857 - accuracy: 0.5714
The training set accuracy is 0.5714
1/1 [==============================] - 0s 19ms/step - loss: 0.6916 - accuracy: 0.5000
The training set accuracy is 0.5000

5. Results

5.1 Conclusions

  • Using lagged qPCR measurements, it is easier to predict the next measurement for mice in the healthy donor group compared to those in the UC donor group using both autoregressive models and LSTMs, which performed very similarly on both groups. The low $R^2$ values for mice in the UC donor group suggest that their qPCR measurements may be more variable.
  • We can accurately distinguish between mice in the two donor groups using CNNs based on individual measurements of ASV relative abundances and knowledge of perturbations. This indicates that the composition of bacteria in the mice microbiomes respond very differently to the same perturbations for mice in the two groups.
  • It is difficult to distingush between mice in the two donor groups without knowledge of perturbations, even by training LSTM models on the entire history of ASV relative abundances.
  • Pre-aggregating data, such as by computing the relative abundances per phylum (or some other taxonomic grouping), does not improve predictive performance. Exposure to the full granularity of the data seems to be beneficial even if the data might be noisy.

5.2 Limitations

  • Given the small sample size of nine mice, we did not have enough data to determine whether our LSTM classification models were truly unable to accurately distinguish between mice in the two groups based on ASV relative abundance histories.

5.3 Future Work

  • We could try predicting qPCR measurements using more lagged timesteps.
  • Perhaps reservoir computing would allow us to forecast qPCR measurements based on longer histories, especially given more data.
  • Now that we know that knowledge of ASV relative abundance and perturbations are predictive of mouse donor group, we could explore more deeply how relative abundancies change in response to particular perturbations for mice in the two groups.