Line data Source code
1 : /*
2 : * Copyright (C) 2020-2025 MEmilio
3 : *
4 : * Authors: Henrik Zunker, Wadim Koslow, Daniel Abele, Martin J. Kühn
5 : *
6 : * Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
7 : *
8 : * Licensed under the Apache License, Version 2.0 (the "License");
9 : * you may not use this file except in compliance with the License.
10 : * You may obtain a copy of the License at
11 : *
12 : * http://www.apache.org/licenses/LICENSE-2.0
13 : *
14 : * Unless required by applicable law or agreed to in writing, software
15 : * distributed under the License is distributed on an "AS IS" BASIS,
16 : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 : * See the License for the specific language governing permissions and
18 : * limitations under the License.
19 : */
20 : #ifndef MIO_ODE_SECIRTS_PARAMETER_SPACE_H
21 : #define MIO_ODE_SECIRTS_PARAMETER_SPACE_H
22 :
23 : #include "memilio/mobility/metapopulation_mobility_instant.h"
24 : #include "memilio/utils/memory.h"
25 : #include "memilio/utils/logging.h"
26 : #include "memilio/utils/parameter_distributions.h"
27 : #include "ode_secirts/model.h"
28 :
#include <assert.h>
#include <string>
#include <vector>
#include <random>
#include <memory>
#include <numeric>
34 :
35 : namespace mio
36 : {
37 : namespace osecirts
38 : {
39 : /**
40 : * Draws a sample from the specified distributions for all parameters
41 : * related to the demographics, e.g., population.
42 : * @tparam FP Floating point type, e.g., double.
43 : * @param[inout] model Model including contact patterns for all age groups
44 : */
45 : template <typename FP = double>
46 117 : void draw_sample_demographics(Model<FP>& model)
47 : {
48 117 : model.parameters.template get<ICUCapacity<FP>>().draw_sample();
49 117 : model.parameters.template get<TestAndTraceCapacity<FP>>().draw_sample();
50 :
51 117 : const static std::vector<InfectionState> naive_immunity_states = {
52 : InfectionState::SusceptibleNaive,
53 : InfectionState::ExposedNaive,
54 : InfectionState::InfectedNoSymptomsNaive,
55 : InfectionState::InfectedNoSymptomsNaiveConfirmed,
56 : InfectionState::InfectedSymptomsNaive,
57 : InfectionState::InfectedSymptomsNaiveConfirmed,
58 : InfectionState::InfectedSevereNaive,
59 : InfectionState::InfectedCriticalNaive,
60 : InfectionState::DeadNaive,
61 : };
62 :
63 117 : const static std::vector<InfectionState> partial_immunity_states = {
64 : InfectionState::SusceptiblePartialImmunity, InfectionState::ExposedPartialImmunity,
65 : InfectionState::InfectedNoSymptomsPartialImmunity, InfectionState::InfectedNoSymptomsPartialImmunityConfirmed,
66 : InfectionState::InfectedSymptomsPartialImmunity, InfectionState::InfectedSymptomsPartialImmunityConfirmed,
67 : InfectionState::InfectedSeverePartialImmunity, InfectionState::InfectedCriticalPartialImmunity,
68 : InfectionState::TemporaryImmunePartialImmunity, InfectionState::DeadPartialImmunity,
69 : };
70 :
71 117 : const static std::vector<InfectionState> improved_immunity_states = {
72 : InfectionState::SusceptibleImprovedImmunity, InfectionState::ExposedImprovedImmunity,
73 : InfectionState::InfectedNoSymptomsImprovedImmunity, InfectionState::InfectedNoSymptomsImprovedImmunityConfirmed,
74 : InfectionState::InfectedSymptomsImprovedImmunity, InfectionState::InfectedSymptomsImprovedImmunityConfirmed,
75 : InfectionState::InfectedSevereImprovedImmunity, InfectionState::InfectedCriticalImprovedImmunity,
76 : InfectionState::TemporaryImmuneImprovedImmunity, InfectionState::DeadImprovedImmunity,
77 : };
78 :
79 : // helper function to calculate the total population of a layer for a given age group
80 3789 : auto calculate_layer_total = [&model](const std::vector<InfectionState>& states, AgeGroup ageGroup) {
81 3672 : return std::accumulate(states.begin(), states.end(), 0.0,
82 70992 : [&model, &ageGroup](double sum, const InfectionState& state) {
83 70992 : return sum + model.populations[{ageGroup, state}];
84 3672 : });
85 : };
86 :
87 : // helper function to adjust the susceptible population of a layer for a given age group
88 5625 : auto adjust_susceptible_population = [&model](AgeGroup i, double diff, InfectionState susceptibleState) {
89 1836 : model.populations[{i, susceptibleState}] += diff;
90 1836 : if (model.populations[{i, susceptibleState}] < 0) {
91 0 : mio::log_warning("Negative population in State " + std::to_string(static_cast<size_t>(susceptibleState)) +
92 : " for age group " + std::to_string(static_cast<size_t>(i)) + ". Setting to 0.");
93 0 : model.populations[{i, susceptibleState}] = 0;
94 : }
95 : };
96 :
97 729 : for (auto i = AgeGroup(0); i < model.parameters.get_num_groups(); i++) {
98 :
99 612 : const double group_naive_total = calculate_layer_total(naive_immunity_states, i);
100 612 : const double group_partial_total = calculate_layer_total(partial_immunity_states, i);
101 612 : const double group_improved_total = calculate_layer_total(improved_immunity_states, i);
102 :
103 : //sample initial compartments (with exceptions)
104 18360 : for (auto inf_state = Index<InfectionState>(0); inf_state < InfectionState::Count; ++inf_state) {
105 35496 : if (inf_state != InfectionState::DeadNaive && //not sampled, fixed from data
106 69768 : inf_state != InfectionState::DeadPartialImmunity && //not sampled, fixed from data
107 34272 : inf_state != InfectionState::DeadImprovedImmunity) { //not sampled, fixed from data
108 15912 : model.populations[{i, inf_state}].draw_sample();
109 : }
110 : }
111 612 : const double diff_naive = group_naive_total - calculate_layer_total(naive_immunity_states, i);
112 612 : const double diff_partial = group_partial_total - calculate_layer_total(partial_immunity_states, i);
113 612 : const double diff_improved = group_improved_total - calculate_layer_total(improved_immunity_states, i);
114 :
115 612 : adjust_susceptible_population(i, diff_naive, InfectionState::SusceptibleNaive);
116 612 : adjust_susceptible_population(i, diff_partial, InfectionState::SusceptiblePartialImmunity);
117 612 : adjust_susceptible_population(i, diff_improved, InfectionState::SusceptibleImprovedImmunity);
118 : }
119 117 : }
120 :
121 : /**
122 : * Draws a sample from the specified distributions for all parameters
123 : * related to the infection.
124 : *
125 : * @tparam FP Floating point type, e.g., double.
126 : * @param[inout] model Model including contact patterns for all age groups.
127 : */
128 : template <typename FP = double>
129 108 : void draw_sample_infection(Model<FP>& model)
130 : {
131 108 : model.parameters.template get<Seasonality<FP>>().draw_sample();
132 :
133 : //not age dependent
134 108 : model.parameters.template get<TimeExposed<FP>>()[AgeGroup(0)].draw_sample();
135 108 : model.parameters.template get<TimeInfectedNoSymptoms<FP>>()[AgeGroup(0)].draw_sample();
136 108 : model.parameters.template get<RelativeTransmissionNoSymptoms<FP>>()[AgeGroup(0)].draw_sample();
137 108 : model.parameters.template get<RiskOfInfectionFromSymptomatic<FP>>()[AgeGroup(0)].draw_sample();
138 108 : model.parameters.template get<MaxRiskOfInfectionFromSymptomatic<FP>>()[AgeGroup(0)].draw_sample();
139 108 : model.parameters.template get<TimeTemporaryImmunityPI<FP>>()[AgeGroup(0)].draw_sample();
140 108 : model.parameters.template get<TimeTemporaryImmunityII<FP>>()[AgeGroup(0)].draw_sample();
141 :
142 108 : model.parameters.template get<ReducExposedPartialImmunity<FP>>()[AgeGroup(0)].draw_sample();
143 108 : model.parameters.template get<ReducExposedImprovedImmunity<FP>>()[AgeGroup(0)].draw_sample();
144 108 : model.parameters.template get<ReducInfectedSymptomsPartialImmunity<FP>>()[AgeGroup(0)].draw_sample();
145 108 : model.parameters.template get<ReducInfectedSymptomsImprovedImmunity<FP>>()[AgeGroup(0)].draw_sample();
146 108 : model.parameters.template get<ReducInfectedSevereCriticalDeadPartialImmunity<FP>>()[AgeGroup(0)].draw_sample();
147 108 : model.parameters.template get<ReducInfectedSevereCriticalDeadImprovedImmunity<FP>>()[AgeGroup(0)].draw_sample();
148 108 : model.parameters.template get<ReducTimeInfectedMild<FP>>()[AgeGroup(0)].draw_sample();
149 :
150 666 : for (auto i = AgeGroup(0); i < model.parameters.get_num_groups(); i++) {
151 : //not age dependent
152 558 : model.parameters.template get<TimeExposed<FP>>()[i] =
153 1116 : model.parameters.template get<TimeExposed<FP>>()[AgeGroup(0)];
154 558 : model.parameters.template get<TimeInfectedNoSymptoms<FP>>()[i] =
155 1116 : model.parameters.template get<TimeInfectedNoSymptoms<FP>>()[AgeGroup(0)];
156 558 : model.parameters.template get<RelativeTransmissionNoSymptoms<FP>>()[i] =
157 1116 : model.parameters.template get<RelativeTransmissionNoSymptoms<FP>>()[AgeGroup(0)];
158 558 : model.parameters.template get<RiskOfInfectionFromSymptomatic<FP>>()[i] =
159 1116 : model.parameters.template get<RiskOfInfectionFromSymptomatic<FP>>()[AgeGroup(0)];
160 558 : model.parameters.template get<MaxRiskOfInfectionFromSymptomatic<FP>>()[i] =
161 1116 : model.parameters.template get<MaxRiskOfInfectionFromSymptomatic<FP>>()[AgeGroup(0)];
162 :
163 558 : model.parameters.template get<ReducExposedPartialImmunity<FP>>()[i] =
164 1116 : model.parameters.template get<ReducExposedPartialImmunity<FP>>()[AgeGroup(0)];
165 558 : model.parameters.template get<ReducExposedImprovedImmunity<FP>>()[i] =
166 1116 : model.parameters.template get<ReducExposedImprovedImmunity<FP>>()[AgeGroup(0)];
167 558 : model.parameters.template get<ReducInfectedSymptomsPartialImmunity<FP>>()[i] =
168 1116 : model.parameters.template get<ReducInfectedSymptomsPartialImmunity<FP>>()[AgeGroup(0)];
169 558 : model.parameters.template get<ReducInfectedSymptomsImprovedImmunity<FP>>()[i] =
170 1116 : model.parameters.template get<ReducInfectedSymptomsImprovedImmunity<FP>>()[AgeGroup(0)];
171 558 : model.parameters.template get<ReducInfectedSevereCriticalDeadPartialImmunity<FP>>()[i] =
172 1116 : model.parameters.template get<ReducInfectedSevereCriticalDeadPartialImmunity<FP>>()[AgeGroup(0)];
173 558 : model.parameters.template get<ReducInfectedSevereCriticalDeadImprovedImmunity<FP>>()[i] =
174 1116 : model.parameters.template get<ReducInfectedSevereCriticalDeadImprovedImmunity<FP>>()[AgeGroup(0)];
175 558 : model.parameters.template get<ReducTimeInfectedMild<FP>>()[i] =
176 1116 : model.parameters.template get<ReducTimeInfectedMild<FP>>()[AgeGroup(0)];
177 :
178 : //age dependent
179 558 : model.parameters.template get<TimeInfectedSymptoms<FP>>()[i].draw_sample();
180 558 : model.parameters.template get<TimeInfectedSevere<FP>>()[i].draw_sample();
181 558 : model.parameters.template get<TimeInfectedCritical<FP>>()[i].draw_sample();
182 :
183 558 : model.parameters.template get<TransmissionProbabilityOnContact<FP>>()[i].draw_sample();
184 558 : model.parameters.template get<RecoveredPerInfectedNoSymptoms<FP>>()[i].draw_sample();
185 558 : model.parameters.template get<DeathsPerCritical<FP>>()[i].draw_sample();
186 558 : model.parameters.template get<SeverePerInfectedSymptoms<FP>>()[i].draw_sample();
187 558 : model.parameters.template get<CriticalPerSevere<FP>>()[i].draw_sample();
188 : }
189 108 : }
190 :
191 : /**
192 : * Draws a sample from model parameter distributions and stores sample values
193 : * as parameters values (cf. UncertainValue and Parameters classes).
194 : *
195 : * @tparam FP Floating point type, e.g., double.
196 : * @param[inout] model Model including contact patterns for all age groups.
197 : */
198 : template <typename FP = double>
199 9 : void draw_sample(Model<FP>& model)
200 : {
201 9 : draw_sample_infection(model);
202 9 : draw_sample_demographics(model);
203 9 : model.parameters.template get<ContactPatterns<FP>>().draw_sample();
204 9 : model.apply_constraints();
205 9 : }
206 :
207 : /**
208 : * Draws samples for each model node in a graph.
209 : * Some parameters are shared between nodes and are only sampled once.
210 : *
211 : * @tparam FP Floating point type, e.g., double.
212 : * @param graph Graph to be sampled.
213 : * @return Graph with nodes and edges from the input graph sampled.
214 : */
215 : template <typename FP = double>
216 99 : Graph<Model<FP>, MobilityParameters<FP>> draw_sample(Graph<Model<FP>, MobilityParameters<FP>>& graph)
217 : {
218 99 : Graph<Model<FP>, MobilityParameters<FP>> sampled_graph;
219 :
220 : //sample global parameters
221 99 : auto& shared_params_model = graph.nodes()[0].property;
222 99 : draw_sample_infection(shared_params_model);
223 99 : auto& shared_contacts = shared_params_model.parameters.template get<ContactPatterns<FP>>();
224 99 : shared_contacts.draw_sample_dampings();
225 99 : auto& shared_dynamic_npis = shared_params_model.parameters.template get<DynamicNPIsInfectedSymptoms<FP>>();
226 99 : shared_dynamic_npis.draw_sample();
227 :
228 315 : for (auto& params_node : graph.nodes()) {
229 108 : auto& node_model = params_node.property;
230 :
231 : //sample local parameters
232 108 : draw_sample_demographics(params_node.property);
233 :
234 : //copy global parameters
235 : //save demographic parameters so they aren't overwritten
236 108 : auto local_icu_capacity = node_model.parameters.template get<ICUCapacity<FP>>();
237 108 : auto local_tnt_capacity = node_model.parameters.template get<TestAndTraceCapacity<FP>>();
238 108 : auto local_holidays = node_model.parameters.template get<ContactPatterns<FP>>().get_school_holidays();
239 108 : auto local_daily_v1 = node_model.parameters.template get<DailyPartialVaccinations<FP>>();
240 108 : auto local_daily_v2 = node_model.parameters.template get<DailyFullVaccinations<FP>>();
241 108 : auto local_daily_v3 = node_model.parameters.template get<DailyBoosterVaccinations<FP>>();
242 108 : node_model.parameters = shared_params_model.parameters;
243 108 : node_model.parameters.template get<ICUCapacity<FP>>() = local_icu_capacity;
244 108 : node_model.parameters.template get<TestAndTraceCapacity<FP>>() = local_tnt_capacity;
245 108 : node_model.parameters.template get<ContactPatterns<FP>>().get_school_holidays() = local_holidays;
246 108 : node_model.parameters.template get<DailyPartialVaccinations<FP>>() = local_daily_v1;
247 108 : node_model.parameters.template get<DailyFullVaccinations<FP>>() = local_daily_v2;
248 108 : node_model.parameters.template get<DailyBoosterVaccinations<FP>>() = local_daily_v3;
249 :
250 108 : node_model.parameters.template get<ContactPatterns<FP>>().make_matrix();
251 108 : node_model.apply_constraints();
252 :
253 108 : sampled_graph.add_node(params_node.id, node_model);
254 : }
255 :
256 117 : for (auto& edge : graph.edges()) {
257 9 : auto edge_params = edge.property;
258 : //no dynamic NPIs
259 : //TODO: add switch to optionally enable dynamic NPIs to edges
260 9 : sampled_graph.add_edge(edge.start_node_idx, edge.end_node_idx, edge_params);
261 : }
262 :
263 198 : return sampled_graph;
264 99 : }
265 :
266 : } // namespace osecirts
267 : } // namespace mio
268 :
269 : #endif // MIO_ODE_SECIRTS_PARAMETER_SPACE_H
|