LCOV - code coverage report
Current view: top level - models/ode_secirts - parameter_space.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 118 120 98.3 %
Date: 2025-01-17 12:16:22 Functions: 7 7 100.0 %

          Line data    Source code
       1             : /* 
       2             : * Copyright (C) 2020-2025 MEmilio
       3             : *
       4             : * Authors: Henrik Zunker, Wadim Koslow, Daniel Abele, Martin J. Kühn
       5             : *
       6             : * Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
       7             : *
       8             : * Licensed under the Apache License, Version 2.0 (the "License");
       9             : * you may not use this file except in compliance with the License.
      10             : * You may obtain a copy of the License at
      11             : *
      12             : *     http://www.apache.org/licenses/LICENSE-2.0
      13             : *
      14             : * Unless required by applicable law or agreed to in writing, software
      15             : * distributed under the License is distributed on an "AS IS" BASIS,
      16             : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      17             : * See the License for the specific language governing permissions and
      18             : * limitations under the License.
      19             : */
      20             : #ifndef MIO_ODE_SECIRTS_PARAMETER_SPACE_H
      21             : #define MIO_ODE_SECIRTS_PARAMETER_SPACE_H
      22             : 
#include "memilio/mobility/metapopulation_mobility_instant.h"
#include "memilio/utils/memory.h"
#include "memilio/utils/logging.h"
#include "memilio/utils/parameter_distributions.h"
#include "ode_secirts/model.h"

#include <assert.h>
#include <memory>
#include <numeric>
#include <random>
#include <string>
#include <vector>
      34             : 
      35             : namespace mio
      36             : {
      37             : namespace osecirts
      38             : {
      39             : /**
      40             :  * Draws a sample from the specified distributions for all parameters
      41             :  * related to the demographics, e.g., population.
      42             :  * @tparam FP Floating point type, e.g., double.
      43             :  * @param[inout] model Model including contact patterns for all age groups
      44             :  */
      45             : template <typename FP = double>
      46         117 : void draw_sample_demographics(Model<FP>& model)
      47             : {
      48         117 :     model.parameters.template get<ICUCapacity<FP>>().draw_sample();
      49         117 :     model.parameters.template get<TestAndTraceCapacity<FP>>().draw_sample();
      50             : 
      51         117 :     const static std::vector<InfectionState> naive_immunity_states = {
      52             :         InfectionState::SusceptibleNaive,
      53             :         InfectionState::ExposedNaive,
      54             :         InfectionState::InfectedNoSymptomsNaive,
      55             :         InfectionState::InfectedNoSymptomsNaiveConfirmed,
      56             :         InfectionState::InfectedSymptomsNaive,
      57             :         InfectionState::InfectedSymptomsNaiveConfirmed,
      58             :         InfectionState::InfectedSevereNaive,
      59             :         InfectionState::InfectedCriticalNaive,
      60             :         InfectionState::DeadNaive,
      61             :     };
      62             : 
      63         117 :     const static std::vector<InfectionState> partial_immunity_states = {
      64             :         InfectionState::SusceptiblePartialImmunity,        InfectionState::ExposedPartialImmunity,
      65             :         InfectionState::InfectedNoSymptomsPartialImmunity, InfectionState::InfectedNoSymptomsPartialImmunityConfirmed,
      66             :         InfectionState::InfectedSymptomsPartialImmunity,   InfectionState::InfectedSymptomsPartialImmunityConfirmed,
      67             :         InfectionState::InfectedSeverePartialImmunity,     InfectionState::InfectedCriticalPartialImmunity,
      68             :         InfectionState::TemporaryImmunePartialImmunity,    InfectionState::DeadPartialImmunity,
      69             :     };
      70             : 
      71         117 :     const static std::vector<InfectionState> improved_immunity_states = {
      72             :         InfectionState::SusceptibleImprovedImmunity,        InfectionState::ExposedImprovedImmunity,
      73             :         InfectionState::InfectedNoSymptomsImprovedImmunity, InfectionState::InfectedNoSymptomsImprovedImmunityConfirmed,
      74             :         InfectionState::InfectedSymptomsImprovedImmunity,   InfectionState::InfectedSymptomsImprovedImmunityConfirmed,
      75             :         InfectionState::InfectedSevereImprovedImmunity,     InfectionState::InfectedCriticalImprovedImmunity,
      76             :         InfectionState::TemporaryImmuneImprovedImmunity,    InfectionState::DeadImprovedImmunity,
      77             :     };
      78             : 
      79             :     // helper function to calculate the total population of a layer for a given age group
      80        3789 :     auto calculate_layer_total = [&model](const std::vector<InfectionState>& states, AgeGroup ageGroup) {
      81        3672 :         return std::accumulate(states.begin(), states.end(), 0.0,
      82       70992 :                                [&model, &ageGroup](double sum, const InfectionState& state) {
      83       70992 :                                    return sum + model.populations[{ageGroup, state}];
      84        3672 :                                });
      85             :     };
      86             : 
      87             :     // helper function to adjust the susceptible population of a layer for a given age group
      88        5625 :     auto adjust_susceptible_population = [&model](AgeGroup i, double diff, InfectionState susceptibleState) {
      89        1836 :         model.populations[{i, susceptibleState}] += diff;
      90        1836 :         if (model.populations[{i, susceptibleState}] < 0) {
      91           0 :             mio::log_warning("Negative population in State " + std::to_string(static_cast<size_t>(susceptibleState)) +
      92             :                              " for age group " + std::to_string(static_cast<size_t>(i)) + ". Setting to 0.");
      93           0 :             model.populations[{i, susceptibleState}] = 0;
      94             :         }
      95             :     };
      96             : 
      97         729 :     for (auto i = AgeGroup(0); i < model.parameters.get_num_groups(); i++) {
      98             : 
      99         612 :         const double group_naive_total    = calculate_layer_total(naive_immunity_states, i);
     100         612 :         const double group_partial_total  = calculate_layer_total(partial_immunity_states, i);
     101         612 :         const double group_improved_total = calculate_layer_total(improved_immunity_states, i);
     102             : 
     103             :         //sample initial compartments (with exceptions)
     104       18360 :         for (auto inf_state = Index<InfectionState>(0); inf_state < InfectionState::Count; ++inf_state) {
     105       35496 :             if (inf_state != InfectionState::DeadNaive && //not sampled, fixed from data
     106       69768 :                 inf_state != InfectionState::DeadPartialImmunity && //not sampled, fixed from data
     107       34272 :                 inf_state != InfectionState::DeadImprovedImmunity) { //not sampled, fixed from data
     108       15912 :                 model.populations[{i, inf_state}].draw_sample();
     109             :             }
     110             :         }
     111         612 :         const double diff_naive    = group_naive_total - calculate_layer_total(naive_immunity_states, i);
     112         612 :         const double diff_partial  = group_partial_total - calculate_layer_total(partial_immunity_states, i);
     113         612 :         const double diff_improved = group_improved_total - calculate_layer_total(improved_immunity_states, i);
     114             : 
     115         612 :         adjust_susceptible_population(i, diff_naive, InfectionState::SusceptibleNaive);
     116         612 :         adjust_susceptible_population(i, diff_partial, InfectionState::SusceptiblePartialImmunity);
     117         612 :         adjust_susceptible_population(i, diff_improved, InfectionState::SusceptibleImprovedImmunity);
     118             :     }
     119         117 : }
     120             : 
     121             : /**
     122             :  * Draws a sample from the specified distributions for all parameters 
     123             :  * related to the infection.
     124             :  * 
     125             :  * @tparam FP Floating point type, e.g., double.
     126             :  * @param[inout] model Model including contact patterns for all age groups.
     127             :  */
     128             : template <typename FP = double>
     129         108 : void draw_sample_infection(Model<FP>& model)
     130             : {
     131         108 :     model.parameters.template get<Seasonality<FP>>().draw_sample();
     132             : 
     133             :     //not age dependent
     134         108 :     model.parameters.template get<TimeExposed<FP>>()[AgeGroup(0)].draw_sample();
     135         108 :     model.parameters.template get<TimeInfectedNoSymptoms<FP>>()[AgeGroup(0)].draw_sample();
     136         108 :     model.parameters.template get<RelativeTransmissionNoSymptoms<FP>>()[AgeGroup(0)].draw_sample();
     137         108 :     model.parameters.template get<RiskOfInfectionFromSymptomatic<FP>>()[AgeGroup(0)].draw_sample();
     138         108 :     model.parameters.template get<MaxRiskOfInfectionFromSymptomatic<FP>>()[AgeGroup(0)].draw_sample();
     139         108 :     model.parameters.template get<TimeTemporaryImmunityPI<FP>>()[AgeGroup(0)].draw_sample();
     140         108 :     model.parameters.template get<TimeTemporaryImmunityII<FP>>()[AgeGroup(0)].draw_sample();
     141             : 
     142         108 :     model.parameters.template get<ReducExposedPartialImmunity<FP>>()[AgeGroup(0)].draw_sample();
     143         108 :     model.parameters.template get<ReducExposedImprovedImmunity<FP>>()[AgeGroup(0)].draw_sample();
     144         108 :     model.parameters.template get<ReducInfectedSymptomsPartialImmunity<FP>>()[AgeGroup(0)].draw_sample();
     145         108 :     model.parameters.template get<ReducInfectedSymptomsImprovedImmunity<FP>>()[AgeGroup(0)].draw_sample();
     146         108 :     model.parameters.template get<ReducInfectedSevereCriticalDeadPartialImmunity<FP>>()[AgeGroup(0)].draw_sample();
     147         108 :     model.parameters.template get<ReducInfectedSevereCriticalDeadImprovedImmunity<FP>>()[AgeGroup(0)].draw_sample();
     148         108 :     model.parameters.template get<ReducTimeInfectedMild<FP>>()[AgeGroup(0)].draw_sample();
     149             : 
     150         666 :     for (auto i = AgeGroup(0); i < model.parameters.get_num_groups(); i++) {
     151             :         //not age dependent
     152         558 :         model.parameters.template get<TimeExposed<FP>>()[i] =
     153        1116 :             model.parameters.template get<TimeExposed<FP>>()[AgeGroup(0)];
     154         558 :         model.parameters.template get<TimeInfectedNoSymptoms<FP>>()[i] =
     155        1116 :             model.parameters.template get<TimeInfectedNoSymptoms<FP>>()[AgeGroup(0)];
     156         558 :         model.parameters.template get<RelativeTransmissionNoSymptoms<FP>>()[i] =
     157        1116 :             model.parameters.template get<RelativeTransmissionNoSymptoms<FP>>()[AgeGroup(0)];
     158         558 :         model.parameters.template get<RiskOfInfectionFromSymptomatic<FP>>()[i] =
     159        1116 :             model.parameters.template get<RiskOfInfectionFromSymptomatic<FP>>()[AgeGroup(0)];
     160         558 :         model.parameters.template get<MaxRiskOfInfectionFromSymptomatic<FP>>()[i] =
     161        1116 :             model.parameters.template get<MaxRiskOfInfectionFromSymptomatic<FP>>()[AgeGroup(0)];
     162             : 
     163         558 :         model.parameters.template get<ReducExposedPartialImmunity<FP>>()[i] =
     164        1116 :             model.parameters.template get<ReducExposedPartialImmunity<FP>>()[AgeGroup(0)];
     165         558 :         model.parameters.template get<ReducExposedImprovedImmunity<FP>>()[i] =
     166        1116 :             model.parameters.template get<ReducExposedImprovedImmunity<FP>>()[AgeGroup(0)];
     167         558 :         model.parameters.template get<ReducInfectedSymptomsPartialImmunity<FP>>()[i] =
     168        1116 :             model.parameters.template get<ReducInfectedSymptomsPartialImmunity<FP>>()[AgeGroup(0)];
     169         558 :         model.parameters.template get<ReducInfectedSymptomsImprovedImmunity<FP>>()[i] =
     170        1116 :             model.parameters.template get<ReducInfectedSymptomsImprovedImmunity<FP>>()[AgeGroup(0)];
     171         558 :         model.parameters.template get<ReducInfectedSevereCriticalDeadPartialImmunity<FP>>()[i] =
     172        1116 :             model.parameters.template get<ReducInfectedSevereCriticalDeadPartialImmunity<FP>>()[AgeGroup(0)];
     173         558 :         model.parameters.template get<ReducInfectedSevereCriticalDeadImprovedImmunity<FP>>()[i] =
     174        1116 :             model.parameters.template get<ReducInfectedSevereCriticalDeadImprovedImmunity<FP>>()[AgeGroup(0)];
     175         558 :         model.parameters.template get<ReducTimeInfectedMild<FP>>()[i] =
     176        1116 :             model.parameters.template get<ReducTimeInfectedMild<FP>>()[AgeGroup(0)];
     177             : 
     178             :         //age dependent
     179         558 :         model.parameters.template get<TimeInfectedSymptoms<FP>>()[i].draw_sample();
     180         558 :         model.parameters.template get<TimeInfectedSevere<FP>>()[i].draw_sample();
     181         558 :         model.parameters.template get<TimeInfectedCritical<FP>>()[i].draw_sample();
     182             : 
     183         558 :         model.parameters.template get<TransmissionProbabilityOnContact<FP>>()[i].draw_sample();
     184         558 :         model.parameters.template get<RecoveredPerInfectedNoSymptoms<FP>>()[i].draw_sample();
     185         558 :         model.parameters.template get<DeathsPerCritical<FP>>()[i].draw_sample();
     186         558 :         model.parameters.template get<SeverePerInfectedSymptoms<FP>>()[i].draw_sample();
     187         558 :         model.parameters.template get<CriticalPerSevere<FP>>()[i].draw_sample();
     188             :     }
     189         108 : }
     190             : 
     191             : /**
     192             :  * Draws a sample from model parameter distributions and stores sample values
     193             :  * as parameters values (cf. UncertainValue and Parameters classes).
     194             :  * 
     195             :  * @tparam FP Floating point type, e.g., double.
     196             :  * @param[inout] model Model including contact patterns for all age groups.
     197             :  */
     198             : template <typename FP = double>
     199           9 : void draw_sample(Model<FP>& model)
     200             : {
     201           9 :     draw_sample_infection(model);
     202           9 :     draw_sample_demographics(model);
     203           9 :     model.parameters.template get<ContactPatterns<FP>>().draw_sample();
     204           9 :     model.apply_constraints();
     205           9 : }
     206             : 
     207             : /**
     208             :  * Draws samples for each model node in a graph.
     209             :  * Some parameters are shared between nodes and are only sampled once.
     210             :  * 
     211             :  * @tparam FP Floating point type, e.g., double.
     212             :  * @param graph Graph to be sampled.
     213             :  * @return Graph with nodes and edges from the input graph sampled.
     214             :  */
     215             : template <typename FP = double>
     216          99 : Graph<Model<FP>, MobilityParameters<FP>> draw_sample(Graph<Model<FP>, MobilityParameters<FP>>& graph)
     217             : {
     218          99 :     Graph<Model<FP>, MobilityParameters<FP>> sampled_graph;
     219             : 
     220             :     //sample global parameters
     221          99 :     auto& shared_params_model = graph.nodes()[0].property;
     222          99 :     draw_sample_infection(shared_params_model);
     223          99 :     auto& shared_contacts = shared_params_model.parameters.template get<ContactPatterns<FP>>();
     224          99 :     shared_contacts.draw_sample_dampings();
     225          99 :     auto& shared_dynamic_npis = shared_params_model.parameters.template get<DynamicNPIsInfectedSymptoms<FP>>();
     226          99 :     shared_dynamic_npis.draw_sample();
     227             : 
     228         315 :     for (auto& params_node : graph.nodes()) {
     229         108 :         auto& node_model = params_node.property;
     230             : 
     231             :         //sample local parameters
     232         108 :         draw_sample_demographics(params_node.property);
     233             : 
     234             :         //copy global parameters
     235             :         //save demographic parameters so they aren't overwritten
     236         108 :         auto local_icu_capacity = node_model.parameters.template get<ICUCapacity<FP>>();
     237         108 :         auto local_tnt_capacity = node_model.parameters.template get<TestAndTraceCapacity<FP>>();
     238         108 :         auto local_holidays     = node_model.parameters.template get<ContactPatterns<FP>>().get_school_holidays();
     239         108 :         auto local_daily_v1     = node_model.parameters.template get<DailyPartialVaccinations<FP>>();
     240         108 :         auto local_daily_v2     = node_model.parameters.template get<DailyFullVaccinations<FP>>();
     241         108 :         auto local_daily_v3     = node_model.parameters.template get<DailyBoosterVaccinations<FP>>();
     242         108 :         node_model.parameters   = shared_params_model.parameters;
     243         108 :         node_model.parameters.template get<ICUCapacity<FP>>()                           = local_icu_capacity;
     244         108 :         node_model.parameters.template get<TestAndTraceCapacity<FP>>()                  = local_tnt_capacity;
     245         108 :         node_model.parameters.template get<ContactPatterns<FP>>().get_school_holidays() = local_holidays;
     246         108 :         node_model.parameters.template get<DailyPartialVaccinations<FP>>()              = local_daily_v1;
     247         108 :         node_model.parameters.template get<DailyFullVaccinations<FP>>()                 = local_daily_v2;
     248         108 :         node_model.parameters.template get<DailyBoosterVaccinations<FP>>()              = local_daily_v3;
     249             : 
     250         108 :         node_model.parameters.template get<ContactPatterns<FP>>().make_matrix();
     251         108 :         node_model.apply_constraints();
     252             : 
     253         108 :         sampled_graph.add_node(params_node.id, node_model);
     254             :     }
     255             : 
     256         117 :     for (auto& edge : graph.edges()) {
     257           9 :         auto edge_params = edge.property;
     258             :         //no dynamic NPIs
     259             :         //TODO: add switch to optionally enable dynamic NPIs to edges
     260           9 :         sampled_graph.add_edge(edge.start_node_idx, edge.end_node_idx, edge_params);
     261             :     }
     262             : 
     263         198 :     return sampled_graph;
     264          99 : }
     265             : 
     266             : } // namespace osecirts
     267             : } // namespace mio
     268             : 
     269             : #endif // MIO_ODE_SECIRTS_PARAMETER_SPACE_H

Generated by: LCOV version 1.14