Coverage for /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/memilio/surrogatemodel/ode_secir_simple/model.py: 39%
103 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-01-17 11:58 +0000
1#############################################################################
2# Copyright (C) 2020-2025 MEmilio
3#
4# Authors: Agatha Schmidt, Henrik Zunker
5#
6# Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
7#
8# Licensed under the Apache License, Version 2.0 (the "License");
9# you may not use this file except in compliance with the License.
10# You may obtain a copy of the License at
11#
12# http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing, software
15# distributed under the License is distributed on an "AS IS" BASIS,
16# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17# See the License for the specific language governing permissions and
18# limitations under the License.
19#############################################################################
20import os
21import pickle
23import matplotlib.pyplot as plt
24import numpy as np
25import pandas as pd
26import tensorflow as tf
28from memilio.simulation.osecir import InfectionState
29from memilio.surrogatemodel.ode_secir_simple import network_architectures
def plot_compartment_prediction_model(
        inputs, labels, model=None, plot_compartment='InfectedSymptoms',
        max_subplots=8):
    """! Plot prediction of the model and label for one compartment.

    If model is None, only the inputs and labels of the selected compartment
    are plotted, without any predictions. The figure is written to
    'plots/evaluation_secir_simple_<plot_compartment>.png'; the 'plots'
    directory is created if it does not exist.

    @param inputs Test inputs for model prediction. Each run inputs[n] must
        support .numpy() and is indexed as [time, compartment].
    @param labels Test labels, indexed like inputs.
    @param model Trained model (optional). Called on a batch of one run.
    @param plot_compartment Name of the InfectionState compartment to plot.
    @param max_subplots Maximum number of simulation runs to plot.
    @raises ValueError If plot_compartment is not an InfectionState name.
    """
    # Resolve the compartment column before creating a figure, so that an
    # invalid name does not leave an empty, dangling figure behind.
    compartment_names = [
        compartment.name for compartment in InfectionState.values()]
    if plot_compartment not in compartment_names:
        raise ValueError('Compartment name given could not be found.')
    plot_compartment_index = compartment_names.index(plot_compartment)

    input_width = inputs.shape[1]
    label_width = labels.shape[1]
    max_n = min(max_subplots, inputs.shape[0])

    plt.figure(figsize=(12, 8))
    for n in range(max_n):
        plt.subplot(max_n, 1, n + 1)
        plt.ylabel(plot_compartment)

        input_array = inputs[n].numpy()
        label_array = labels[n].numpy()
        # Inputs occupy days [0, input_width); labels follow directly after.
        plt.plot(
            np.arange(0, input_width),
            input_array[:, plot_compartment_index],
            label='Inputs', marker='.', zorder=-10)
        plt.scatter(
            np.arange(input_width, input_width + label_width),
            label_array[:, plot_compartment_index],
            edgecolors='k', label='Labels', c='#2ca02c', s=64)

        if model is not None:
            # The model expects a batch dimension, so wrap the single run.
            input_series = tf.expand_dims(inputs[n], axis=0)
            pred = model(input_series).numpy()
            plt.scatter(np.arange(input_width, input_width + pred.shape[-2]),
                        pred[0, :, plot_compartment_index],
                        marker='X', edgecolors='k', label='Predictions',
                        c='#ff7f0e', s=64)

        if n == 0:
            # The series above carry labels; show them once, in the top panel.
            plt.legend()

    plt.xlabel('days')
    os.makedirs("plots", exist_ok=True)
    plt.savefig('plots/evaluation_secir_simple_' + plot_compartment + '.png')
def network_fit(path, model, max_epochs=30, early_stop=500, plot=True):
    """! Training and evaluation of a given model with mean squared error loss and Adam optimizer using the mean absolute error as a metric.

    The dataset 'data_secir_simple.pickle' is loaded from path and split into
    training, validation and test sets with split_data.

    @param path Directory containing 'data_secir_simple.pickle'.
    @param model Keras model to compile and train.
    @param max_epochs Maximum number of training epochs.
    @param early_stop Patience (in epochs) of the EarlyStopping callback
        that monitors the validation loss.
    @param plot If True, plot the losses and the test predictions for
        'InfectedSymptoms', and print the per-compartment test error.
    @return The History object returned by model.fit.
    @raises ValueError If the dataset file does not exist in path.
    """
    data_file = os.path.join(path, 'data_secir_simple.pickle')
    if not os.path.isfile(data_file):
        # Raise explicitly so a missing dataset is reported with a clear
        # message instead of a later, unrelated FileNotFoundError.
        raise ValueError("no dataset found in path: " + path)

    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only load datasets produced by a trusted source.
    with open(data_file, 'rb') as file:
        data = pickle.load(file)
    data_splitted = split_data(data['inputs'], data['labels'])

    train_inputs = data_splitted['train_inputs']
    train_labels = data_splitted['train_labels']
    valid_inputs = data_splitted['valid_inputs']
    valid_labels = data_splitted['valid_labels']
    test_inputs = data_splitted['test_inputs']
    test_labels = data_splitted['test_labels']

    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                      patience=early_stop,
                                                      mode='min')

    model.compile(
        loss=tf.keras.losses.MeanSquaredError(),
        optimizer=tf.keras.optimizers.Adam(),
        metrics=[tf.keras.metrics.MeanAbsoluteError()])

    history = model.fit(train_inputs, train_labels, epochs=max_epochs,
                        validation_data=(valid_inputs, valid_labels),
                        callbacks=[early_stopping])

    if plot:
        plot_losses(history)
        plot_compartment_prediction_model(
            test_inputs, test_labels, model=model,
            plot_compartment='InfectedSymptoms', max_subplots=3)
        df = get_test_statistic(test_inputs, test_labels, model)
        print(df)
    return history
def plot_losses(history):
    """! Plots the training and validation losses of the model training.

    Draws history.history['loss'] and history.history['val_loss'] on the
    current matplotlib figure, saves it to 'plots/losses_plot.png' (creating
    the 'plots' directory if needed) and shows it.

    @param history Model training history, e.g. as returned by model.fit.
    """
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs("plots", exist_ok=True)
    plt.savefig('plots/losses_plot.png')
    plt.show()
def get_test_statistic(test_inputs, test_labels, model):
    """! Calculates the mean absolute percentage error based on the test dataset.

    The error is averaged per compartment over all test runs and time steps.

    @param test_inputs Inputs from test data.
    @param test_labels Labels (output) from test data, shaped
        [batch, time, features] like the model prediction.
    @param model Trained model; its output must support .numpy().
    @return pandas DataFrame with one 'Percentage Error' row per
        InfectionState compartment.
    """
    pred = model(test_inputs).numpy()
    test_labels = np.array(test_labels)

    # NOTE(review): labels equal to zero yield inf/nan relative errors.
    relative_err = np.abs(pred - test_labels) / np.abs(test_labels)
    # reshape [batch, time, features] -> [features, time * batch]; derive the
    # feature count from the data instead of hard-coding it, so a mismatch
    # with the compartment index fails loudly rather than mixing features.
    n_features = relative_err.shape[2]
    relative_err_transformed = relative_err.transpose(2, 0, 1).reshape(
        n_features, -1)
    relative_err_means_percentage = relative_err_transformed.mean(axis=1) * 100
    mean_percentage = pd.DataFrame(
        data=relative_err_means_percentage,
        index=[str(compartment).split('.')[1]
               for compartment in InfectionState.values()],
        columns=['Percentage Error'])

    return mean_percentage
def split_data(inputs, labels, split_train=0.7,
               split_valid=0.2, split_test=0.1):
    """! Split data set in training, validation and testing data sets.

    Samples are split along the first axis, in order. The training and
    validation sizes are the truncated products of the sample count and
    their shares; the test set receives all remaining samples.

    @param inputs Input dataset (at least three axes; axis 2 are features).
    @param labels Label dataset with the same number of samples and features.
    @param split_train Share of training data sets.
    @param split_valid Share of validation data sets.
    @param split_test Share of testing data sets.
    @return Dictionary with the keys 'train_inputs', 'train_labels',
        'valid_inputs', 'valid_labels', 'test_inputs' and 'test_labels'.
    @raises ValueError If the shares sum to more than 1, or if inputs and
        labels differ in sample or feature count.
    """
    total_share = split_train + split_valid + split_test
    if total_share > 1 + 1e-10:
        raise ValueError(
            "Summed data set shares are greater than 1. Please adjust the values.")
    if not (inputs.shape[0] == labels.shape[0]
            and inputs.shape[2] == labels.shape[2]):
        raise ValueError(
            "Number of batches or features different for input and labels")

    n_samples = inputs.shape[0]
    n_train = int(n_samples * split_train)
    n_valid = int(n_samples * split_valid)
    sizes = [n_train, n_valid, n_samples - n_train - n_valid]

    parts_inputs = tf.split(inputs, sizes, 0)
    parts_labels = tf.split(labels, sizes, 0)

    key_pairs = (('train_inputs', 'train_labels'),
                 ('valid_inputs', 'valid_labels'),
                 ('test_inputs', 'test_labels'))
    result = {}
    for (input_key, label_key), part_in, part_lab in zip(
            key_pairs, parts_inputs, parts_labels):
        result[input_key] = part_in
        result[label_key] = part_lab
    return result
if __name__ == "__main__":
    path = os.path.dirname(os.path.realpath(__file__))
    # The dataset is read from a 'data' directory located two levels above
    # this file's directory (path is already fully resolved by realpath).
    path_data = os.path.join(os.path.dirname(os.path.dirname(path)), 'data')
    max_epochs = 400

    # Select the network architecture by name; an unknown name fails here
    # instead of passing a plain string to network_fit.
    model_name = "LSTM"
    model_builders = {
        "Dense": lambda: network_architectures.mlp_multi_input_single_output(),
        "LSTM": lambda: network_architectures.lstm_multi_input_multi_output(30),
        "CNN": lambda: network_architectures.cnn_multi_input_multi_output(30),
    }
    if model_name not in model_builders:
        raise ValueError("Unknown model architecture: " + model_name)
    model = model_builders[model_name]()

    model_output = network_fit(
        path_data, model=model,
        max_epochs=max_epochs)