Coverage for /opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/memilio/surrogatemodel/ode_secir_groups/model.py: 54%

226 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-18 12:29 +0000

1############################################################################# 

2# Copyright (C) 2020-2023 German Aerospace Center (DLR-SC) 

3# 

4# Authors: Agatha Schmidt, Henrik Zunker 

5# 

6# Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de> 

7# 

8# Licensed under the Apache License, Version 2.0 (the "License"); 

9# you may not use this file except in compliance with the License. 

10# You may obtain a copy of the License at 

11# 

12# http://www.apache.org/licenses/LICENSE-2.0 

13# 

14# Unless required by applicable law or agreed to in writing, software 

15# distributed under the License is distributed on an "AS IS" BASIS, 

16# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

17# See the License for the specific language governing permissions and 

18# limitations under the License. 

19############################################################################# 

20from memilio.surrogatemodel.ode_secir_groups import network_architectures 

21from memilio.simulation.osecir import InfectionState 

22import os 

23import pickle 

24 

25import matplotlib.pyplot as plt 

26import numpy as np 

27import pandas as pd 

28import tensorflow as tf 

29 

30 

def _mean_over_age_groups(values, compartment_index, num_compartments):
    """! Average one compartment over all age groups for every row (time step) of values.

    @param values 2D array [time steps, num_age_groups * num_compartments].
    @param compartment_index index of the compartment inside one age group block.
    @param num_compartments number of compartments per age group.
    @return 1D array with one mean value per row.
    """
    return np.asarray(values)[:, compartment_index::num_compartments].mean(axis=1)


def plot_compartment_prediction_model(
        inputs, labels, modeltype, model=None,
        plot_compartment='InfectedSymptoms', max_subplots=8):
    """! Plot prediction of the model and label for one compartment. The average of all age groups is plotted.

    If model is None, we just plot the inputs and labels for the selected compartment without any predictions.
    The feature axis holds one block of num_compartments values per age group,
    so compartment c of every age group is found at c, c + num_compartments, ...

    @param inputs test inputs for model prediction.
    @param labels test labels with shape [num_runs, label_width, features].
    @param modeltype type of model. Can be 'classic' or 'timeseries'.
    @param model trained model.
    @param plot_compartment string name of compartment to be plotted.
    @param max_subplots Number of the simulation runs to be plotted and compared against.
    @raises ValueError if modeltype is unknown or plot_compartment is not a known compartment.
    """
    num_groups = 6
    num_compartments = 8
    num_features = num_groups * num_compartments
    # Non-compartment input features: one damping day plus the flattened contact matrix.
    num_static_features = 1 + num_groups * num_groups

    if modeltype == 'classic':
        # to get the input_width, we first subtract the damping date and the contact matrix entries.
        # Next, we divide by the number of age groups * the number of compartments
        input_width = int(
            (inputs.shape[1] - num_static_features) / num_features)
    elif modeltype == 'timeseries':
        input_width = int(inputs.shape[1])
    else:
        raise ValueError(
            "modeltype must be 'classic' or 'timeseries', got " + repr(modeltype))
    label_width = int(labels.shape[1])

    # The data sets do not contain the "Confirmed" compartments of InfectionState
    # (get_test_statistic drops them when labelling its rows too). The feature
    # index of a compartment is therefore its position in the filtered list,
    # not its position in InfectionState.values().
    excluded = ('InfectedNoSymptomsConfirmed', 'InfectedSymptomsConfirmed')
    compartment_names = [compartment.name
                         for compartment in InfectionState.values()
                         if compartment.name not in excluded]
    if plot_compartment not in compartment_names:
        raise ValueError('Compartment name given could not be found.')
    plot_compartment_index = compartment_names.index(plot_compartment)

    max_n = min(max_subplots, inputs.shape[0])

    plt.figure(figsize=(12, 8))
    for n in range(max_n):
        plt.subplot(max_n, 1, n + 1)
        plt.ylabel(plot_compartment)

        input_array = inputs[n].numpy()
        label_array = labels[n].numpy()

        if modeltype == 'classic':
            # Flat input: the compartment values come first, followed by the
            # contact matrix and the damping day.
            compartment_inputs = input_array[:input_width * num_features]
            compartment_inputs = compartment_inputs.reshape(
                input_width, num_features)
        else:
            # Inputs have the dimensions [num_runs, input_width, features]. The features
            # consist of the compartment data, contact matrices and the damping day;
            # keep only the compartment data.
            compartment_inputs = input_array[:,
                                             :inputs.shape[2] - num_static_features]

        plt.plot(
            np.arange(0, input_width),
            _mean_over_age_groups(
                compartment_inputs, plot_compartment_index, num_compartments),
            label='Inputs', marker='.', zorder=-10)

        plt.scatter(
            np.arange(input_width, input_width + label_width),
            _mean_over_age_groups(
                label_array, plot_compartment_index, num_compartments),
            edgecolors='k', label='Labels', c='#2ca02c', s=64)

        if model is not None:
            input_series = tf.expand_dims(inputs[n], axis=0)
            pred = model(input_series).numpy()
            # Bring the prediction into the label layout [label_width, features]
            # instead of a hard-coded (30, 48).
            pred = pred.reshape((label_width, num_features))

            plt.scatter(
                np.arange(input_width, input_width + pred.shape[-2]),
                _mean_over_age_groups(
                    pred, plot_compartment_index, num_compartments),
                marker='X', edgecolors='k', label='Predictions',
                c='#ff7f0e', s=64)

    plt.xlabel('days')
    plt.legend()
    os.makedirs("plots", exist_ok=True)
    plt.savefig('plots/evaluation_secir_groups_' + plot_compartment + '.png')

136 

137 

def _load_dataset(path):
    """! Load the pickled SECIR age-group dataset 'data_secir_groups.pickle' from path.

    @param path directory containing the dataset.
    @return the unpickled dataset.
    @raises FileNotFoundError if the dataset file does not exist.
    """
    file_path = os.path.join(path, 'data_secir_groups.pickle')
    if not os.path.isfile(file_path):
        raise FileNotFoundError("no dataset found in path: " + path)
    # NOTE(review): pickle can execute arbitrary code; only load trusted data sets.
    with open(file_path, 'rb') as file:
        return pickle.load(file)


def _concat_classic_inputs(compartments, contact_matrices, damping_days):
    """! Build flat inputs: flattened compartments, flattened contact matrix and damping day per run."""
    return tf.concat(
        [tf.cast(flat_input(compartments), tf.float32),
         tf.cast(flat_input(contact_matrices), tf.float32),
         tf.cast(damping_days, tf.float32)],
        axis=1, name='concat')


def _concat_timeseries_inputs(compartments, contact_matrices, damping_days):
    """! Build time series inputs: append each run's flattened contact matrix and damping day to every time step."""
    input_width = compartments.shape[1]
    contacts = tf.cast(flat_input(contact_matrices), tf.float32)
    days = tf.cast(tf.reshape(damping_days, [-1, 1]), tf.float32)
    # [runs, k] -> [runs, input_width, k]: repeat the per-run values for every time step.
    contacts = tf.tile(contacts[:, tf.newaxis, :], [1, input_width, 1])
    days = tf.tile(days[:, tf.newaxis, :], [1, input_width, 1])
    # float32 (as in the classic branch) instead of float16, which overflows
    # above 65504 and keeps only ~3 significant digits.
    return tf.concat(
        (tf.cast(compartments, tf.float32), contacts, days), axis=2)


def network_fit(
        path, model, modeltype, max_epochs=30, early_stop=500, plot=True):
    """! Training and evaluation of a given model with mean absolute percentage error loss and Adam optimizer using the mean squared error as a metric.

    @param path path of the dataset.
    @param model Keras sequential model.
    @param modeltype type of model. Can be 'classic' or 'timeseries'. Data preparation is made based on the modeltype.
    @param max_epochs int maximum number of epochs in training.
    @param early_stop Integer that forces an early stop of training if the given number of epochs does not give a significant reduction of validation loss.
    @param plot if True, plot losses and predictions and print the test statistic.
    @return the Keras training history.
    @raises ValueError if modeltype is unknown.
    @raises FileNotFoundError if no dataset is found in path.
    """
    if modeltype == 'classic':
        concat_inputs = _concat_classic_inputs
    elif modeltype == 'timeseries':
        concat_inputs = _concat_timeseries_inputs
    else:
        raise ValueError(
            "modeltype must be 'classic' or 'timeseries', got " + repr(modeltype))

    data = _load_dataset(path)
    data_splitted = split_data(data['inputs'], data['labels'])
    # All three splits use the same share arithmetic, so the parts stay aligned run by run.
    contact_matrices = split_contact_matrices(tf.stack(data["contact_matrix"]))
    damping_days = split_damping_days(data['damping_day'])

    train_inputs = concat_inputs(
        data_splitted["train_inputs"], contact_matrices['train'], damping_days['train'])
    valid_inputs = concat_inputs(
        data_splitted["valid_inputs"], contact_matrices['valid'], damping_days['valid'])
    test_inputs = concat_inputs(
        data_splitted["test_inputs"], contact_matrices['test'], damping_days['test'])
    train_labels = data_splitted["train_labels"]
    valid_labels = data_splitted["valid_labels"]
    test_labels = data_splitted["test_labels"]

    batch_size = 32

    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                      patience=early_stop,
                                                      mode='min')

    model.compile(
        loss=tf.keras.losses.MeanAbsolutePercentageError(),
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        metrics=[tf.keras.metrics.MeanSquaredError()])

    history = model.fit(train_inputs, train_labels, epochs=max_epochs,
                        validation_data=(valid_inputs, valid_labels),
                        batch_size=batch_size,
                        callbacks=[early_stopping])

    if plot:
        plot_losses(history)
        plot_compartment_prediction_model(
            test_inputs, test_labels, modeltype, model=model,
            plot_compartment='InfectedSymptoms', max_subplots=3)
        df = get_test_statistic(test_inputs, test_labels, model)
        print(df)
    return history

308 

309 

def plot_losses(history):
    """! Plots the training and validation losses of the model training.

    The figure is saved to plots/losses_plot.png and then shown.

    @param history model training history; history.history must contain
        the keys 'loss' and 'val_loss'.
    """
    # Start a new figure so the curves are not drawn into a figure that an
    # earlier plot left open.
    plt.figure()
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    os.makedirs("plots", exist_ok=True)
    plt.savefig('plots/losses_plot.png')
    plt.show()

326 

327 

def get_test_statistic(test_inputs, test_labels, model):
    """! Calculates the mean absolute percentage error based on the test dataset.

    The relative error |prediction - label| / |label| is averaged per
    compartment over all test runs, time steps and age groups.

    @param test_inputs inputs from test data.
    @param test_labels labels (output) from test data, shaped
        [num_runs, label_width, num_age_groups * num_compartments].
    @param model trained model.
    @return pandas DataFrame indexed by compartment name with the single
        column 'Percentage Error'.
    """
    # The data sets do not contain the "Confirmed" compartments.
    excluded = ('InfectedNoSymptomsConfirmed', 'InfectedSymptomsConfirmed')
    compartments = [compartment.name
                    for compartment in InfectionState.values()
                    if compartment.name not in excluded]
    num_compartments = len(compartments)

    test_labels = np.array(test_labels)
    # Bring the prediction into the label layout (a no-op when the model
    # already returns [num_runs, label_width, features]).
    pred = model(test_inputs).numpy().reshape(test_labels.shape)

    # NOTE: labels equal to 0 yield inf/nan entries here.
    relative_err = np.abs(pred - test_labels) / np.abs(test_labels)
    # The features hold one block of num_compartments values per age group.
    # Split the feature axis into [age group, compartment] and average over
    # everything except the compartment. The previous
    # transpose(2, 0, 1).reshape(8, -1) grouped consecutive features instead,
    # which mixes compartments for more than one age group.
    relative_err = relative_err.reshape(
        test_labels.shape[0], test_labels.shape[1], -1, num_compartments)
    relative_err_means_percentage = relative_err.mean(axis=(0, 1, 2)) * 100

    mean_percentage = pd.DataFrame(
        data=relative_err_means_percentage,
        index=compartments,
        columns=['Percentage Error'])

    return mean_percentage

356 

357 

def split_data(inputs, labels, split_train=0.7,
               split_valid=0.2, split_test=0.1):
    """! Split a data set into training, validation and testing parts along the first axis.

    The training and validation sizes are the truncated shares of the number
    of runs. The test part receives all remaining runs, so split_test is only
    used to validate the shares.

    @param inputs input dataset with shape [runs, time, features].
    @param labels label dataset with shape [runs, time, features].
    @param split_train Share of training data sets.
    @param split_valid Share of validation data sets.
    @param split_test Share of testing data sets.
    @return dict with the keys 'train_inputs', 'train_labels', 'valid_inputs',
        'valid_labels', 'test_inputs' and 'test_labels'.
    @raises ValueError if the shares sum to more than 1, or if the number of
        runs or features of inputs and labels differ.
    """
    total_share = split_train + split_valid + split_test
    if total_share > 1 + 1e-10:
        raise ValueError(
            "Summed data set shares are greater than 1. Please adjust the values.")
    if not (inputs.shape[0] == labels.shape[0]
            and inputs.shape[2] == labels.shape[2]):
        raise ValueError(
            "Number of batches or features different for input and labels")

    num_runs = inputs.shape[0]
    part_sizes = [int(num_runs * split_train), int(num_runs * split_valid)]
    part_sizes.append(num_runs - sum(part_sizes))

    input_parts = tf.split(inputs, part_sizes, 0)
    label_parts = tf.split(labels, part_sizes, 0)

    key_pairs = (('train_inputs', 'train_labels'),
                 ('valid_inputs', 'valid_labels'),
                 ('test_inputs', 'test_labels'))
    result = {}
    for (input_key, label_key), input_part, label_part in zip(
            key_pairs, input_parts, label_parts):
        result[input_key] = input_part
        result[label_key] = label_part
    return result

396 

397 

def flat_input(input):
    """! Flatten all dimensions except the first one.

    @param input array or tensor with shape [n, d1, d2, ...].
    @return tensor with shape [n, d1 * d2 * ...].
    """
    trailing_dims = tf.shape(input)[1:]
    flattened_size = tf.reduce_prod(trailing_dims)
    return tf.reshape(input, [-1, flattened_size])

406 

407 

def split_contact_matrices(contact_matrices, split_train=0.7,
                           split_valid=0.2, split_test=0.1):
    """! Split contact matrices in train, valid and test.

    The part sizes are the truncated train/valid shares of the number of
    matrices; the test part receives the remainder.

    @param contact_matrices contact matrices, one per run along the first axis.
    @param split_train ratio of train datasets
    @param split_valid ratio of validation datasets
    @param split_test ratio of test datasets
    @return dict with the keys "train", "valid" and "test".
    @raises ValueError if the summed ratios are greater than 1.
    """
    # Compare with a tolerance: 0.7 + 0.2 + 0.1 == 0.9999999999999999 in
    # floating point, so an exact "!= 1" test would reject the defaults.
    if split_train + split_valid + split_test > 1 + 1e-10:
        raise ValueError(
            "Summed split ratios are greater than 1. Please adjust the values.")

    n = contact_matrices.shape[0]
    n_train = int(n * split_train)
    n_valid = int(n * split_valid)
    n_test = n - n_train - n_valid

    contact_matrices_train, contact_matrices_valid, contact_matrices_test = tf.split(
        contact_matrices, [n_train, n_valid, n_test], 0)
    data = {
        "train": contact_matrices_train,
        "valid": contact_matrices_valid,
        "test": contact_matrices_test
    }

    return data

436 

437 

def split_damping_days(damping_days, split_train=0.7,
                       split_valid=0.2, split_test=0.1):
    """! Split damping days in train, valid and test.

    The part sizes are the truncated train/valid shares of the number of
    damping days; the test part receives the remainder.

    @param damping_days damping days, one entry per run.
    @param split_train ratio of train datasets
    @param split_valid ratio of validation datasets
    @param split_test ratio of test datasets
    @return dict with the keys "train", "valid" and "test"; each value is
        reshaped to [part size, 1].
    @raises ValueError if the summed ratios are greater than 1.
    """
    # Compare with a tolerance: 0.7 + 0.2 + 0.1 == 0.9999999999999999 in
    # floating point, so an exact "!= 1" test would reject the defaults.
    if split_train + split_valid + split_test > 1 + 1e-10:
        raise ValueError(
            "Summed split ratios are greater than 1. Please adjust the values.")
    damping_days = np.asarray(damping_days)
    n = damping_days.shape[0]
    n_train = int(n * split_train)
    n_valid = int(n * split_valid)
    n_test = n - n_train - n_valid

    damping_days_train, damping_days_valid, damping_days_test = tf.split(
        damping_days, [n_train, n_valid, n_test], 0)
    data = {
        "train": tf.reshape(damping_days_train, [n_train, 1]),
        "valid": tf.reshape(damping_days_valid, [n_valid, 1]),
        "test": tf.reshape(damping_days_test, [n_test, 1])
    }

    return data

465 

466 

def get_input_dim_lstm(path):
    """! Extract the feature dimension of the time series input data.

    The dimension is the number of compartment features per time step, plus
    the number of entries of one contact matrix, plus one for the damping day.

    @param path path to the directory containing 'data_secir_groups.pickle'.
    @return number of input features per time step.
    """
    # NOTE(review): pickle can execute arbitrary code; only load trusted data sets.
    with open(os.path.join(path, 'data_secir_groups.pickle'), 'rb') as file:
        data = pickle.load(file)

    contact_shape = np.asarray(data['contact_matrix']).shape
    input_dim = data['inputs'].shape[2] + contact_shape[1] * contact_shape[2] + 1

    return input_dim

480 

481 

if __name__ == "__main__":
    path = os.path.dirname(os.path.realpath(__file__))
    path_data = os.path.join(os.path.dirname(os.path.realpath(
        os.path.dirname(os.path.realpath(path)))), 'data')
    max_epochs = 100
    label_width = 30

    # Model name -> (lazy constructor, data layout expected by network_fit).
    architectures = {
        "Dense_Single": (
            lambda: network_architectures.mlp_multi_input_single_output(),
            'classic'),
        "Dense": (
            lambda: network_architectures.mlp_multi_input_multi_output(
                label_width),
            'classic'),
        "LSTM": (
            lambda: network_architectures.lstm_multi_input_multi_output(
                label_width),
            'timeseries'),
        "CNN": (
            lambda: network_architectures.cnn_multi_input_multi_output(
                label_width),
            'timeseries'),
    }

    model_name = "CNN"
    if model_name not in architectures:
        raise ValueError("Unknown model: " + model_name)
    build_model, modeltype = architectures[model_name]
    model = build_model()

    history = network_fit(
        path_data, model=model, modeltype=modeltype,
        max_epochs=max_epochs)