Coverage for /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/memilio/surrogatemodel/ode_secir_simple/model.py: 39%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-18 12:29 +0000

1############################################################################# 

2# Copyright (C) 2020-2024 MEmilio 

3# 

4# Authors: Agatha Schmidt, Henrik Zunker 

5# 

6# Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de> 

7# 

8# Licensed under the Apache License, Version 2.0 (the "License"); 

9# you may not use this file except in compliance with the License. 

10# You may obtain a copy of the License at 

11# 

12# http://www.apache.org/licenses/LICENSE-2.0 

13# 

14# Unless required by applicable law or agreed to in writing, software 

15# distributed under the License is distributed on an "AS IS" BASIS, 

16# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

17# See the License for the specific language governing permissions and 

18# limitations under the License. 

19############################################################################# 

20import os 

21import pickle 

22 

23import matplotlib.pyplot as plt 

24import numpy as np 

25import pandas as pd 

26import tensorflow as tf 

27 

28from memilio.simulation.osecir import InfectionState 

29from memilio.surrogatemodel.ode_secir_simple import network_architectures 

30 

31 

def plot_compartment_prediction_model(
        inputs, labels, model=None, plot_compartment='InfectedSymptoms',
        max_subplots=8):
    """! Plot prediction of the model and label for one compartment.

    If model is None, only the inputs and labels of the selected compartment
    are plotted, without any predictions. The figure is written to
    'plots/evaluation_secir_simple_<plot_compartment>.png'; the 'plots'
    directory is created in the current working directory if necessary.

    @param inputs test inputs for model prediction. Indexed as
        [run, time, compartment]; every inputs[n] must provide .numpy().
    @param labels test labels, indexed like inputs.
    @param model trained model, called with a single run that has been
        expanded by a leading batch axis. Its result must provide .numpy().
    @param plot_compartment string name of compartment to be plotted.
    @param max_subplots Maximum number of simulation runs to be plotted and
        compared against.
    @throws ValueError if plot_compartment is not the name of any entry of
        InfectionState.values().
    """
    # Resolve the compartment column first so that an invalid name does not
    # leave an empty, never-saved figure behind.
    compartment_names = [
        compartment.name for compartment in InfectionState.values()]
    if plot_compartment not in compartment_names:
        raise ValueError('Compartment name given could not be found.')
    plot_compartment_index = compartment_names.index(plot_compartment)

    input_width = inputs.shape[1]
    label_width = labels.shape[1]
    # Never request more subplots than there are runs available.
    max_n = min(max_subplots, inputs.shape[0])

    plt.figure(figsize=(12, 8))
    for n in range(max_n):
        plt.subplot(max_n, 1, n+1)
        plt.ylabel(plot_compartment)

        input_array = inputs[n].numpy()
        label_array = labels[n].numpy()
        plt.plot(
            np.arange(0, input_width),
            input_array[:, plot_compartment_index],
            label='Inputs', marker='.', zorder=-10)
        # Labels continue on the time axis right after the input window.
        plt.scatter(
            np.arange(input_width, input_width + label_width),
            label_array[:, plot_compartment_index],
            edgecolors='k', label='Labels', c='#2ca02c', s=64)

        if model is not None:
            # The model is called on a batch of size one.
            input_series = tf.expand_dims(inputs[n], axis=0)
            pred = model(input_series)
            pred = pred.numpy()
            plt.scatter(np.arange(input_width, input_width+pred.shape[-2]),
                        pred[0, :, plot_compartment_index],
                        marker='X', edgecolors='k', label='Predictions',
                        c='#ff7f0e', s=64)

    plt.xlabel('days')
    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs("plots", exist_ok=True)
    plt.savefig('plots/evaluation_secir_simple_' + plot_compartment + '.png')

87 

88 

def network_fit(path, model, max_epochs=30, early_stop=500, plot=True):
    """! Training and evaluation of a given model with mean squared error loss and Adam optimizer using the mean absolute error as a metric.

    The dataset is read from the file 'data_secir_simple.pickle' inside path
    and must be a pickled mapping with the keys 'inputs' and 'labels'. It is
    divided into training, validation and test data with split_data using
    its default shares.

    @param path path of the directory containing the dataset.
    @param model Keras sequential model.
    @param max_epochs int maximum number of epochs in training.
    @param early_stop Integer that forces an early stop of training if the given number of epochs does not give a significant reduction of validation loss.
    @param plot If True, the losses and the 'InfectedSymptoms' predictions
        are plotted and the test statistic is printed after training.
    @return History object returned by model.fit.
    @throws ValueError if no dataset file exists in path.
    """
    dataset_file = os.path.join(path, 'data_secir_simple.pickle')
    if not os.path.isfile(dataset_file):
        # The exception has to be raised explicitly; merely constructing it
        # would let execution continue into open() below.
        raise ValueError(f"no dataset found in path: {path}")

    # NOTE(review): pickle.load can execute arbitrary code while loading.
    # Only point this function at dataset files from a trusted source.
    with open(dataset_file, 'rb') as file:
        data = pickle.load(file)
    data_splitted = split_data(data['inputs'], data['labels'])

    train_inputs = data_splitted['train_inputs']
    train_labels = data_splitted['train_labels']
    valid_inputs = data_splitted['valid_inputs']
    valid_labels = data_splitted['valid_labels']
    test_inputs = data_splitted['test_inputs']
    test_labels = data_splitted['test_labels']

    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                      patience=early_stop,
                                                      mode='min')

    model.compile(
        loss=tf.keras.losses.MeanSquaredError(),
        optimizer=tf.keras.optimizers.Adam(),
        metrics=[tf.keras.metrics.MeanAbsoluteError()])

    history = model.fit(train_inputs, train_labels, epochs=max_epochs,
                        validation_data=(valid_inputs, valid_labels),
                        callbacks=[early_stopping])

    if plot:
        plot_losses(history)
        plot_compartment_prediction_model(
            test_inputs, test_labels, model=model,
            plot_compartment='InfectedSymptoms', max_subplots=3)
        df = get_test_statistic(test_inputs, test_labels, model)
        print(df)
    return history

135 

136 

def plot_losses(history):
    """! Plots the losses of the model training.

    Training and validation loss are drawn over the epochs, the figure is
    saved as 'plots/losses_plot.png' (the 'plots' directory is created in
    the current working directory if necessary) and then shown.

    @param history model training history. Its attribute history has to be
        a mapping with the entries 'loss' and 'val_loss'.
    """
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    # Legend entries follow the order of the two plot calls above.
    plt.legend(['train', 'val'], loc='upper left')
    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    os.makedirs("plots", exist_ok=True)
    plt.savefig('plots/losses_plot.png')
    plt.show()

153 

154 

def get_test_statistic(test_inputs, test_labels, model):
    """! Calculates the mean absolute percentage error based on the test dataset.

    The error is averaged per compartment over all runs and time points.

    @param test_inputs inputs from test data.
    @param test_labels labels (output) from test data, indexed as
        [batch, time, features].
    @param model trained model. Its result must provide .numpy() and have
        the same layout as test_labels.
    @return pandas DataFrame with the single column 'Percentage Error' and
        one row per entry of InfectionState.values().
        NOTE(review): this presumes that the number of features equals the
        number of infection states - confirm against the dataset. Label
        entries equal to zero presumably yield inf/nan in the result.
    """

    pred = model(test_inputs)
    pred = pred.numpy()
    test_labels = np.array(test_labels)

    diff = pred - test_labels
    relative_err = (abs(diff))/abs(test_labels)
    # reshape [batch, time, features] -> [features, time * batch]
    # The feature count is taken from the data instead of being hard-coded,
    # so other numbers of compartments are not silently mixed together.
    n_features = relative_err.shape[2]
    relative_err_transformed = relative_err.transpose(
        2, 0, 1).reshape(n_features, -1)
    relative_err_means_percentage = relative_err_transformed.mean(axis=1) * 100
    mean_percentage = pd.DataFrame(
        data=relative_err_means_percentage,
        # str(compartment) has the form '<EnumName>.<CompartmentName>'.
        index=[str(compartment).split('.')[1]
               for compartment in InfectionState.values()],
        columns=['Percentage Error'])

    return mean_percentage

180 

181 

def split_data(inputs, labels, split_train=0.7,
               split_valid=0.2, split_test=0.1):
    """! Divide a data set into training, validation and testing parts.

    The parts are cut along the first axis in the order training,
    validation, testing. The testing part receives all samples that are
    left over after the training and validation parts have been taken.

    @param inputs input dataset
    @param labels label dataset
    @param split_train Share of training data sets.
    @param split_valid Share of validation data sets.
    @param split_test Share of testing data sets.
    @return dict with the entries 'train_inputs', 'train_labels',
        'valid_inputs', 'valid_labels', 'test_inputs' and 'test_labels'.
    @throws ValueError if the shares sum up to more than 1 or if inputs and
        labels differ in their first or third dimension.
    """
    # Small tolerance so that rounding in the sum does not trigger the error.
    if split_train + split_valid + split_test > 1 + 1e-10:
        raise ValueError(
            "Summed data set shares are greater than 1. Please adjust the values.")

    shapes_agree = (inputs.shape[0] == labels.shape[0]
                    and inputs.shape[2] == labels.shape[2])
    if not shapes_agree:
        raise ValueError(
            "Number of batches or features different for input and labels")

    sample_count = inputs.shape[0]
    train_count = int(sample_count * split_train)
    valid_count = int(sample_count * split_valid)
    # Whatever is not used for training or validation goes to testing.
    part_sizes = [train_count, valid_count,
                  sample_count - train_count - valid_count]

    input_parts = tf.split(inputs, part_sizes, 0)
    label_parts = tf.split(labels, part_sizes, 0)
    inputs_train, inputs_valid, inputs_test = input_parts
    labels_train, labels_valid, labels_test = label_parts

    return {
        'train_inputs': inputs_train,
        'train_labels': labels_train,
        'valid_inputs': inputs_valid,
        'valid_labels': labels_valid,
        'test_inputs': inputs_test,
        'test_labels': labels_test
    }

220 

221 

if __name__ == "__main__":
    # Directory that contains this script.
    path = os.path.dirname(os.path.realpath(__file__))
    # The dataset is looked up in '<two directory levels above path>/data'.
    parent_dir = os.path.realpath(os.path.dirname(os.path.realpath(path)))
    path_data = os.path.join(os.path.dirname(parent_dir), 'data')
    max_epochs = 400

    # Name of the architecture to train; replaced by the built network below.
    model = "LSTM"
    # The constructors are wrapped in lambdas so that only the selected
    # architecture is looked up and built.
    builders = {
        "Dense": lambda: network_architectures.mlp_multi_input_single_output(),
        "LSTM": lambda: network_architectures.lstm_multi_input_multi_output(30),
        "CNN": lambda: network_architectures.cnn_multi_input_multi_output(30),
    }
    if model in builders:
        model = builders[model]()

    model_output = network_fit(
        path_data, model=model,
        max_epochs=max_epochs)

238 max_epochs=max_epochs)