Coverage for /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/memilio/surrogatemodel/ode_secir_groups/model.py: 54%
226 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-01-17 11:58 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-01-17 11:58 +0000
1#############################################################################
2# Copyright (C) 2020-2025 MEmilio
3#
4# Authors: Agatha Schmidt, Henrik Zunker
5#
6# Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
7#
8# Licensed under the Apache License, Version 2.0 (the "License");
9# you may not use this file except in compliance with the License.
10# You may obtain a copy of the License at
11#
12# http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing, software
15# distributed under the License is distributed on an "AS IS" BASIS,
16# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17# See the License for the specific language governing permissions and
18# limitations under the License.
19#############################################################################
20from memilio.surrogatemodel.ode_secir_groups import network_architectures
21from memilio.simulation.osecir import InfectionState
22import os
23import pickle
25import matplotlib.pyplot as plt
26import numpy as np
27import pandas as pd
28import tensorflow as tf
def plot_compartment_prediction_model(
        inputs, labels, modeltype, model=None,
        plot_compartment='InfectedSymptoms', max_subplots=8):
    """! Plot prediction of the model and label for one compartment. The average of all age groups is plotted.

    If model is None, only the inputs and labels for the selected compartment are plotted, without predictions.
    The figure is saved to 'plots/evaluation_secir_groups_<plot_compartment>.png'.

    @param inputs test inputs for model prediction.
    @param labels test labels.
    @param modeltype type of model. Can be 'classic' or 'timeseries'
    @param model trained model.
    @param plot_compartment string name of compartment to be plotted.
    @param max_subplots Number of the simulation runs to be plotted and compared against.
    @throws ValueError if modeltype is unknown or the compartment name cannot be found.
    """
    num_groups = 6
    num_compartments = 8
    # Number of compartment values per day (all age groups).
    num_features = num_groups * num_compartments
    # Entries of an input that are not compartment data: the damping day
    # and the flattened contact matrix.
    num_extra_features = 1 + num_groups * num_groups

    if modeltype == 'classic':
        # To get the input_width, we first subtract the damping day and the
        # contact matrix entries. Next, we divide by the number of
        # age groups * the number of compartments.
        input_width = int(
            (inputs.shape[1] - num_extra_features) / num_features)
    elif modeltype == 'timeseries':
        input_width = int(inputs.shape[1])
    else:
        raise ValueError(
            "modeltype must be 'classic' or 'timeseries', got: "
            + str(modeltype))
    label_width = int(labels.shape[1])

    # Validate the compartment name before any figure is created.
    compartment_names = [
        compartment.name for compartment in InfectionState.values()]
    if plot_compartment not in compartment_names:
        raise ValueError('Compartment name given could not be found.')
    # NOTE(review): the index is taken from the full InfectionState list while
    # the data is sliced with a stride of num_compartments (8). If the data
    # omits the "Confirmed" compartments (as get_test_statistic assumes), the
    # index may be shifted for compartments listed after them -- confirm.
    plot_compartment_index = compartment_names.index(plot_compartment)

    plt.figure(figsize=(12, 8))
    max_n = min(max_subplots, inputs.shape[0])
    input_days = np.arange(0, input_width)
    label_days = np.arange(input_width, input_width + label_width)

    for n in range(max_n):
        plt.subplot(max_n, 1, n + 1)
        plt.ylabel(plot_compartment)

        input_array = inputs[n].numpy()
        label_array = labels[n].numpy()

        if modeltype == 'classic':
            # Flat input: the compartment data comes first; reshape it to
            # [input_width, features].
            input_plot = input_array[:input_width * num_features].reshape(
                input_width, num_features)
        else:
            # Inputs has the dimensions [num_runs, input_width, features].
            # The features consist of the compartment data, contact matrices
            # and the damping day; drop the trailing non-compartment entries.
            input_plot = input_array[
                :, :inputs.shape[2] - num_extra_features]

        # Mean of the selected compartment over all age groups, per day.
        mean_per_day_input = input_plot[
            :, plot_compartment_index::num_compartments].mean(axis=1)
        plt.plot(
            input_days,
            mean_per_day_input,
            label='Inputs', marker='.', zorder=-10)

        mean_per_day = label_array[
            :, plot_compartment_index::num_compartments].mean(axis=1)
        plt.scatter(
            label_days,
            mean_per_day,
            edgecolors='k', label='Labels', c='#2ca02c', s=64)

        if model is not None:
            input_series = tf.expand_dims(inputs[n], axis=0)
            pred = model(input_series).numpy()
            # One row per predicted day; the number of days is inferred from
            # the model output instead of being hard-coded.
            pred = pred.reshape((-1, num_features))
            mean_per_day_pred = pred[
                :, plot_compartment_index::num_compartments].mean(axis=1)

            plt.scatter(
                np.arange(input_width, input_width + pred.shape[0]),
                mean_per_day_pred,
                marker='X', edgecolors='k', label='Predictions',
                c='#ff7f0e', s=64)

    plt.xlabel('days')
    plt.legend()
    os.makedirs("plots", exist_ok=True)
    plt.savefig('plots/evaluation_secir_groups_' + plot_compartment + '.png')
def _load_secir_groups_dataset(path):
    """! Load the pickled dataset 'data_secir_groups.pickle' from the given directory.

    @param path directory containing the dataset file.
    @return the unpickled dataset.
    @throws ValueError if the dataset file does not exist.
    """
    file_path = os.path.join(path, 'data_secir_groups.pickle')
    if not os.path.isfile(file_path):
        raise ValueError("no dataset found in path: " + path)
    # NOTE: pickle.load can execute arbitrary code; only load trusted datasets.
    with open(file_path, 'rb') as file:
        return pickle.load(file)


def _repeat_per_time_step(values, num_steps):
    """! Repeat per-run features for every time step: [runs, features] -> [runs, num_steps, features]."""
    return tf.tile(tf.expand_dims(values, axis=1), [1, num_steps, 1])


def _assemble_inputs(compartments, contact_matrices, damping_days, modeltype):
    """! Concatenate compartment data, flattened contact matrices and damping days into one input tensor.

    @param compartments compartment data of one split with dimensions [runs, input_width, features].
    @param contact_matrices contact matrices of the same split.
    @param damping_days damping days of the same split with dimensions [runs, 1].
    @param modeltype 'classic' (flat inputs) or 'timeseries' (inputs per time step).
    @return input tensor for the model.
    """
    contacts_flat = flat_input(contact_matrices)
    if modeltype == 'classic':
        return tf.concat(
            [tf.cast(flat_input(compartments), tf.float32),
             tf.cast(contacts_flat, tf.float32),
             tf.cast(damping_days, tf.float32)],
            axis=1, name='concat')

    # timeseries: the contact matrix and the damping day are attached to
    # every time step of the compartment data.
    num_steps = compartments.shape[1]
    return tf.concat(
        (tf.cast(compartments, tf.float16),
         tf.cast(_repeat_per_time_step(contacts_flat, num_steps), tf.float16),
         tf.cast(_repeat_per_time_step(damping_days, num_steps), tf.float16)),
        axis=2)


def network_fit(
        path, model, modeltype, max_epochs=30, early_stop=500, plot=True):
    """! Training and evaluation of a given model with mean absolute percentage error loss and Adam optimizer using the mean squared error as a metric.

    @param path path of the dataset.
    @param model Keras sequential model.
    @param modeltype type of model. Can be 'classic' or 'timeseries'. Data preparation is made based on the modeltype.
    @param max_epochs int maximum number of epochs in training.
    @param early_stop Integer that forces an early stop of training if the given number of epochs does not give a significant reduction of validation loss.
    @param plot If True, the losses and a compartment prediction are plotted and the test statistic is printed.
    @return training history as returned by model.fit.
    @throws ValueError if modeltype is unknown or no dataset is found in path.
    """
    if modeltype not in ('classic', 'timeseries'):
        raise ValueError(
            "modeltype must be 'classic' or 'timeseries', got: "
            + str(modeltype))

    data = _load_secir_groups_dataset(path)
    data_splitted = split_data(data['inputs'], data['labels'])

    train_labels = data_splitted["train_labels"]
    valid_labels = data_splitted["valid_labels"]
    test_labels = data_splitted["test_labels"]

    # Contact matrices and damping days are split with the same helpers for
    # both model types so that their split sizes always match split_data.
    contact_matrices = split_contact_matrices(
        tf.stack(data["contact_matrix"]))
    damping_days = split_damping_days(data['damping_day'])

    train_inputs, valid_inputs, test_inputs = (
        _assemble_inputs(
            data_splitted[split + "_inputs"], contact_matrices[split],
            damping_days[split], modeltype)
        for split in ('train', 'valid', 'test'))

    batch_size = 32

    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                      patience=early_stop,
                                                      mode='min')

    model.compile(
        loss=tf.keras.losses.MeanAbsolutePercentageError(),
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        metrics=[tf.keras.metrics.MeanSquaredError()])

    history = model.fit(train_inputs, train_labels, epochs=max_epochs,
                        validation_data=(valid_inputs, valid_labels),
                        batch_size=batch_size,
                        callbacks=[early_stopping])

    if plot:
        plot_losses(history)
        plot_compartment_prediction_model(
            test_inputs, test_labels, modeltype, model=model,
            plot_compartment='InfectedSymptoms', max_subplots=3)
        df = get_test_statistic(test_inputs, test_labels, model)
        print(df)
    return history
def plot_losses(history):
    """! Plot the training and validation loss per epoch.

    The figure is written to 'plots/losses_plot.png' (the directory is
    created if it does not exist) and then shown.

    @param history model training history.
    """
    recorded = history.history
    for curve in ('loss', 'val_loss'):
        plt.plot(recorded[curve])

    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')

    if not os.path.isdir("plots"):
        os.mkdir("plots")
    plt.savefig('plots/losses_plot.png')
    plt.show()
def get_test_statistic(test_inputs, test_labels, model):
    """! Calculates the mean absolute percentage error per compartment based on the test dataset.

    The compartments 'InfectedNoSymptomsConfirmed' and 'InfectedSymptomsConfirmed'
    are not part of the result.

    @param test_inputs inputs from test data.
    @param test_labels labels (output) from test data.
    @param model trained model.
    @return pandas DataFrame with one row per compartment and the column 'Percentage Error'.
    """
    pred = model(test_inputs).numpy()
    test_labels = np.array(test_labels)

    # NOTE: labels equal to zero yield inf/nan entries in the relative error.
    relative_err = np.abs(pred - test_labels) / np.abs(test_labels)

    compartments = [str(compartment).split('.')[1]
                    for compartment in InfectionState.values()]
    compartments = [
        x for x in compartments
        if x not in ('InfectedNoSymptomsConfirmed',
                     'InfectedSymptomsConfirmed')]
    num_compartments = len(compartments)

    # The feature axis holds the compartments of each age group one after
    # another (the compartment index is feature % num_compartments, matching
    # the [index::8] slicing used for plotting). Reshaping to
    # [-1, num_compartments] therefore puts each compartment in its own column,
    # so the mean over axis 0 averages over runs, days and age groups.
    relative_err_means_percentage = relative_err.reshape(
        -1, num_compartments).mean(axis=0) * 100

    mean_percentage = pd.DataFrame(
        data=relative_err_means_percentage,
        index=compartments,
        columns=['Percentage Error'])

    return mean_percentage
def split_data(inputs, labels, split_train=0.7,
               split_valid=0.2, split_test=0.1):
    """! Partition inputs and labels along the first axis into training, validation and test parts.

    The test part receives all runs that remain after the training and
    validation parts have been taken.

    @param inputs input dataset
    @param labels label dataset
    @param split_train Share of training data sets.
    @param split_valid Share of validation data sets.
    @param split_test Share of testing data sets.
    @return dictionary with the keys 'train_inputs', 'train_labels', 'valid_inputs',
        'valid_labels', 'test_inputs' and 'test_labels'.
    """
    total_share = split_train + split_valid + split_test
    if total_share > 1 + 1e-10:
        raise ValueError(
            "Summed data set shares are greater than 1. Please adjust the values.")
    if not (inputs.shape[0] == labels.shape[0]
            and inputs.shape[2] == labels.shape[2]):
        raise ValueError(
            "Number of batches or features different for input and labels")

    num_runs = inputs.shape[0]
    sizes = [int(num_runs * split_train), int(num_runs * split_valid)]
    # Whatever is left over goes to the test part.
    sizes.append(num_runs - sizes[0] - sizes[1])

    input_parts = tf.split(inputs, sizes, 0)
    label_parts = tf.split(labels, sizes, 0)

    return {
        'train_inputs': input_parts[0],
        'train_labels': label_parts[0],
        'valid_inputs': input_parts[1],
        'valid_labels': label_parts[1],
        'test_inputs': input_parts[2],
        'test_labels': label_parts[2]
    }
def flat_input(input):
    """! Collapse all dimensions after the first one into a single dimension.

    @param input input array
    @return tensor with two dimensions; the first dimension is inferred by the reshape.
    """
    trailing_dims = tf.shape(input)[1:]
    flat_size = tf.reduce_prod(trailing_dims)
    return tf.reshape(input, [-1, flat_size])
def split_contact_matrices(contact_matrices, split_train=0.7,
                           split_valid=0.2, split_test=0.1):
    """! Split contact matrices in train, valid and test

    The test part receives all entries that remain after the training and
    validation parts have been taken.

    @param contact_matrices contact matrices
    @param split_train ratio of train datasets
    @param split_valid ratio of validation datasets
    @param split_test ratio of test datasets
    @return dictionary with the keys 'train', 'valid' and 'test'.
    @throws ValueError if the summed ratios are greater than 1.
    """
    # A tolerance is required: 0.7 + 0.2 + 0.1 is not exactly 1 in floating
    # point arithmetic. The check mirrors the one in split_data.
    if split_train + split_valid + split_test > 1 + 1e-10:
        raise ValueError(
            "Summed data set shares are greater than 1. Please adjust the values.")

    n = contact_matrices.shape[0]
    n_train = int(n * split_train)
    n_valid = int(n * split_valid)
    n_test = n - n_train - n_valid

    contact_matrices_train, contact_matrices_valid, contact_matrices_test = tf.split(
        contact_matrices, [n_train, n_valid, n_test], 0)

    return {
        "train": contact_matrices_train,
        "valid": contact_matrices_valid,
        "test": contact_matrices_test
    }
def split_damping_days(damping_days, split_train=0.7,
                       split_valid=0.2, split_test=0.1):
    """! Split damping days in train, valid and test

    The test part receives all entries that remain after the training and
    validation parts have been taken. Each part is reshaped to [n, 1].

    @param damping_days damping days
    @param split_train ratio of train datasets
    @param split_valid ratio of validation datasets
    @param split_test ratio of test datasets
    @return dictionary with the keys 'train', 'valid' and 'test'.
    @throws ValueError if the summed ratios are greater than 1.
    """
    # A tolerance is required: 0.7 + 0.2 + 0.1 is not exactly 1 in floating
    # point arithmetic. The check mirrors the one in split_data.
    if split_train + split_valid + split_test > 1 + 1e-10:
        raise ValueError(
            "Summed data set shares are greater than 1. Please adjust the values.")

    damping_days = np.asarray(damping_days)
    n = damping_days.shape[0]
    n_train = int(n * split_train)
    n_valid = int(n * split_valid)
    n_test = n - n_train - n_valid

    damping_days_train, damping_days_valid, damping_days_test = tf.split(
        damping_days, [n_train, n_valid, n_test], 0)

    return {
        "train": tf.reshape(damping_days_train, [n_train, 1]),
        "valid": tf.reshape(damping_days_valid, [n_valid, 1]),
        "test": tf.reshape(damping_days_test, [n_test, 1])
    }
def get_input_dim_lstm(path):
    """! Extract the dimension of the input data

    The dimension is the number of compartment features plus the number of
    entries of one contact matrix plus one entry for the damping day.

    @param path path to the data
    @return number of input features per time step.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted datasets.
    with open(os.path.join(path, 'data_secir_groups.pickle'), 'rb') as file:
        data = pickle.load(file)

    contact_shape = np.asarray(data['contact_matrix']).shape
    input_dim = data['inputs'].shape[2] + \
        contact_shape[1] * contact_shape[2] + 1

    return input_dim
if __name__ == "__main__":
    # The dataset is expected in the 'data' directory next to the parent
    # directory of this file's directory.
    path = os.path.dirname(os.path.realpath(__file__))
    parent_dir = os.path.realpath(os.path.dirname(os.path.realpath(path)))
    path_data = os.path.join(os.path.dirname(parent_dir), 'data')
    max_epochs = 100
    label_width = 30

    input_dim = get_input_dim_lstm(path_data)

    # Architecture name -> (factory building the model, modeltype used for
    # data preparation). Factories are only called for the selected entry.
    architectures = {
        "Dense_Single": (
            lambda: network_architectures.mlp_multi_input_single_output(),
            'classic'),
        "Dense": (
            lambda: network_architectures.mlp_multi_input_multi_output(
                label_width),
            'classic'),
        "LSTM": (
            lambda: network_architectures.lstm_multi_input_multi_output(
                label_width),
            'timeseries'),
        "CNN": (
            lambda: network_architectures.cnn_multi_input_multi_output(
                label_width),
            'timeseries'),
    }

    model = "CNN"
    if model in architectures:
        build_model, modeltype = architectures[model]
        model = build_model()

    model_output = network_fit(
        path_data, model=model, modeltype=modeltype,
        max_epochs=max_epochs)