Line data Source code
1 : /*
2 : * Copyright (C) 2020-2024 MEmilio
3 : *
4 : * Authors: Daniel Abele, Jan Kleinert, Martin J. Kuehn
5 : *
6 : * Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
7 : *
8 : * Licensed under the Apache License, Version 2.0 (the "License");
9 : * you may not use this file except in compliance with the License.
10 : * You may obtain a copy of the License at
11 : *
12 : * http://www.apache.org/licenses/LICENSE-2.0
13 : *
14 : * Unless required by applicable law or agreed to in writing, software
15 : * distributed under the License is distributed on an "AS IS" BASIS,
16 : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 : * See the License for the specific language governing permissions and
18 : * limitations under the License.
19 : */
20 : #ifndef SEIR_MODEL_H
21 : #define SEIR_MODEL_H
22 :
23 : #include "memilio/compartments/flow_model.h"
24 : #include "memilio/config.h"
25 : #include "memilio/epidemiology/age_group.h"
26 : #include "memilio/epidemiology/populations.h"
27 : #include "memilio/math/interpolation.h"
28 : #include "memilio/utils/time_series.h"
29 : #include "ode_seir/infection_state.h"
30 : #include "ode_seir/parameters.h"
31 :
32 : GCC_CLANG_DIAGNOSTIC(push)
33 : GCC_CLANG_DIAGNOSTIC(ignored "-Wshadow")
34 : #include <Eigen/Dense>
35 : GCC_CLANG_DIAGNOSTIC(pop)
36 :
37 : namespace mio
38 : {
39 : namespace oseir
40 : {
41 :
42 : /********************
43 : * define the model *
44 : ********************/
45 :
// Flows of the SEIR model, listed as (source, target) pairs of infection states:
// Susceptible -> Exposed, Exposed -> Infected and Infected -> Recovered.
// This list is handed to FlowModel as a template argument by the Model class below;
// Model::get_flows addresses the individual flows through these (source, target) pairs.
// clang-format off
using Flows = TypeList<Flow<InfectionState::Susceptible, InfectionState::Exposed>,
                       Flow<InfectionState::Exposed,     InfectionState::Infected>,
                       Flow<InfectionState::Infected,    InfectionState::Recovered>>;
// clang-format on
51 : template <typename FP = ScalarType>
52 : class Model
53 : : public FlowModel<FP, InfectionState, mio::Populations<FP, AgeGroup, InfectionState>, Parameters<FP>, Flows>
54 : {
55 : using Base = FlowModel<FP, InfectionState, mio::Populations<FP, AgeGroup, InfectionState>, Parameters<FP>, Flows>;
56 :
57 : public:
58 : using typename Base::ParameterSet;
59 : using typename Base::Populations;
60 :
61 : Model(const Populations& pop, const ParameterSet& params)
62 : : Base(pop, params)
63 : {
64 : }
65 :
66 147 : Model(int num_agegroups)
67 147 : : Base(Populations({AgeGroup(num_agegroups), InfectionState::Count}), ParameterSet(AgeGroup(num_agegroups)))
68 : {
69 147 : }
70 :
71 386964 : void get_flows(Eigen::Ref<const Vector<FP>> pop, Eigen::Ref<const Vector<FP>> y, FP t,
72 : Eigen::Ref<Vector<FP>> flows) const override
73 : {
74 386964 : const Index<AgeGroup> age_groups = reduce_index<Index<AgeGroup>>(this->populations.size());
75 386964 : const auto& params = this->parameters;
76 :
77 1160892 : for (auto i : make_index_range(age_groups)) {
78 386964 : const size_t Si = this->populations.get_flat_index({i, InfectionState::Susceptible});
79 386964 : const size_t Ei = this->populations.get_flat_index({i, InfectionState::Exposed});
80 386964 : const size_t Ii = this->populations.get_flat_index({i, InfectionState::Infected});
81 :
82 1160892 : for (auto j : make_index_range(age_groups)) {
83 386964 : const size_t Sj = this->populations.get_flat_index({i, InfectionState::Susceptible});
84 386964 : const size_t Ej = this->populations.get_flat_index({j, InfectionState::Exposed});
85 386964 : const size_t Ij = this->populations.get_flat_index({j, InfectionState::Infected});
86 386964 : const size_t Rj = this->populations.get_flat_index({j, InfectionState::Recovered});
87 :
88 386964 : const ScalarType Nj = pop[Sj] + pop[Ej] + pop[Ij] + pop[Rj];
89 386964 : const ScalarType divNj = (Nj < Limits<ScalarType>::zero_tolerance()) ? 0.0 : 1.0 / Nj;
90 386964 : const ScalarType coeffStoE =
91 773928 : params.template get<ContactPatterns<FP>>().get_cont_freq_mat().get_matrix_at(t)(i.get(), j.get()) *
92 386964 : params.template get<TransmissionProbabilityOnContact<FP>>()[i] * divNj;
93 :
94 386964 : flows[Base::template get_flat_flow_index<InfectionState::Susceptible, InfectionState::Exposed>(i)] +=
95 386964 : coeffStoE * y[Si] * pop[Ij];
96 : }
97 386964 : flows[Base::template get_flat_flow_index<InfectionState::Exposed, InfectionState::Infected>(i)] =
98 386964 : (1.0 / params.template get<TimeExposed<FP>>()[i]) * y[Ei];
99 386964 : flows[Base::template get_flat_flow_index<InfectionState::Infected, InfectionState::Recovered>(i)] =
100 386964 : (1.0 / params.template get<TimeInfected<FP>>()[i]) * y[Ii];
101 : }
102 386964 : }
103 :
104 : /**
105 : *@brief Computes the reproduction number at a given index time of the Model output obtained by the Simulation.
106 : *@param t_idx The index time at which the reproduction number is computed.
107 : *@param y The TimeSeries obtained from the Model Simulation.
108 : *@returns The computed reproduction number at the provided index time.
109 : */
110 333 : IOResult<ScalarType> get_reproduction_number(size_t t_idx, const mio::TimeSeries<ScalarType>& y)
111 : {
112 333 : if (!(t_idx < static_cast<size_t>(y.get_num_time_points()))) {
113 9 : return mio::failure(mio::StatusCode::OutOfRange, "t_idx is not a valid index for the TimeSeries");
114 : }
115 :
116 324 : auto const& params = this->parameters;
117 :
118 324 : const size_t num_groups = (size_t)params.get_num_groups();
119 324 : constexpr size_t num_infected_compartments = 2;
120 324 : const size_t total_infected_compartments = num_infected_compartments * num_groups;
121 :
122 324 : ContactMatrixGroup const& contact_matrix = params.template get<ContactPatterns<ScalarType>>();
123 :
124 324 : Eigen::MatrixXd F = Eigen::MatrixXd::Zero(total_infected_compartments, total_infected_compartments);
125 324 : Eigen::MatrixXd V = Eigen::MatrixXd::Zero(total_infected_compartments, total_infected_compartments);
126 :
127 648 : for (auto i = AgeGroup(0); i < AgeGroup(num_groups); i++) {
128 324 : size_t Si = this->populations.get_flat_index({i, InfectionState::Susceptible});
129 648 : for (auto j = AgeGroup(0); j < AgeGroup(num_groups); j++) {
130 :
131 324 : const ScalarType Nj = this->populations.get_group_total(j);
132 324 : const ScalarType divNj = (Nj < 1e-12) ? 0.0 : 1.0 / Nj;
133 :
134 648 : ScalarType coeffStoE = contact_matrix.get_matrix_at(y.get_time(t_idx))(i.get(), j.get()) *
135 324 : params.template get<TransmissionProbabilityOnContact<ScalarType>>()[i] * divNj;
136 324 : F((size_t)i, (size_t)j + num_groups) = coeffStoE * y.get_value(t_idx)[Si];
137 : }
138 :
139 324 : ScalarType T_Ei = params.template get<mio::oseir::TimeExposed<ScalarType>>()[i];
140 324 : ScalarType T_Ii = params.template get<mio::oseir::TimeInfected<ScalarType>>()[i];
141 324 : V((size_t)i, (size_t)i) = 1.0 / T_Ei;
142 324 : V((size_t)i + num_groups, (size_t)i) = -1.0 / T_Ei;
143 324 : V((size_t)i + num_groups, (size_t)i + num_groups) = 1.0 / T_Ii;
144 : }
145 :
146 324 : V = V.inverse();
147 :
148 324 : Eigen::MatrixXd NextGenMatrix = Eigen::MatrixXd::Zero(total_infected_compartments, total_infected_compartments);
149 324 : NextGenMatrix = F * V;
150 :
151 : //Compute the largest eigenvalue in absolute value
152 324 : Eigen::ComplexEigenSolver<Eigen::MatrixXd> ces;
153 :
154 324 : ces.compute(NextGenMatrix);
155 324 : const Eigen::VectorXcd eigen_vals = ces.eigenvalues();
156 :
157 324 : Eigen::VectorXd eigen_vals_abs;
158 324 : eigen_vals_abs.resize(eigen_vals.size());
159 :
160 972 : for (int i = 0; i < eigen_vals.size(); i++) {
161 648 : eigen_vals_abs[i] = std::abs(eigen_vals[i]);
162 : }
163 324 : return mio::success(eigen_vals_abs.maxCoeff());
164 324 : }
165 :
166 : /**
167 : *@brief Computes the reproduction number for all time points of the Model output obtained by the Simulation.
168 : *@param y The TimeSeries obtained from the Model Simulation.
169 : *@returns vector containing all reproduction numbers
170 : */
171 27 : Eigen::VectorXd get_reproduction_numbers(const mio::TimeSeries<ScalarType>& y)
172 : {
173 27 : auto num_time_points = y.get_num_time_points();
174 27 : Eigen::VectorXd temp(num_time_points);
175 216 : for (size_t i = 0; i < static_cast<size_t>(num_time_points); i++) {
176 189 : temp[i] = get_reproduction_number(i, y).value();
177 : }
178 54 : return temp;
179 27 : }
180 :
181 : /**
182 : *@brief Computes the reproduction number at a given time point of the Model output obtained by the Simulation. If the particular time point is not inside the output, a linearly interpolated value is returned.
183 : *@param t_value The time point at which the reproduction number is computed.
184 : *@param y The TimeSeries obtained from the Model Simulation.
185 : *@returns The computed reproduction number at the provided time point, potentially using linear interpolation.
186 : */
187 99 : IOResult<ScalarType> get_reproduction_number(ScalarType t_value, const mio::TimeSeries<ScalarType>& y)
188 : {
189 99 : if (t_value < y.get_time(0) || t_value > y.get_last_time()) {
190 : return mio::failure(mio::StatusCode::OutOfRange,
191 27 : "Cannot interpolate reproduction number outside computed horizon of the TimeSeries");
192 : }
193 :
194 72 : if (t_value == y.get_time(0)) {
195 18 : return mio::success(get_reproduction_number((size_t)0, y).value());
196 : }
197 :
198 108 : auto times = std::vector<ScalarType>(y.get_times().begin(), y.get_times().end());
199 :
200 54 : auto time_late = std::distance(times.begin(), std::lower_bound(times.begin(), times.end(), t_value));
201 :
202 54 : ScalarType y1 = get_reproduction_number(static_cast<size_t>(time_late - 1), y).value();
203 54 : ScalarType y2 = get_reproduction_number(static_cast<size_t>(time_late), y).value();
204 :
205 54 : auto result = linear_interpolation(t_value, y.get_time(time_late - 1), y.get_time(time_late), y1, y2);
206 54 : return mio::success(static_cast<ScalarType>(result));
207 54 : }
208 :
209 : /**
210 : * serialize this.
211 : * @see mio::serialize
212 : */
213 : template <class IOContext>
214 : void serialize(IOContext& io) const
215 : {
216 : auto obj = io.create_object("Model");
217 : obj.add_element("Parameters", this->parameters);
218 : obj.add_element("Populations", this->populations);
219 : }
220 :
221 : /**
222 : * deserialize an object of this class.
223 : * @see mio::deserialize
224 : */
225 : template <class IOContext>
226 : static IOResult<Model> deserialize(IOContext& io)
227 : {
228 : auto obj = io.expect_object("Model");
229 : auto par = obj.expect_element("Parameters", Tag<ParameterSet>{});
230 : auto pop = obj.expect_element("Populations", Tag<Populations>{});
231 : return apply(
232 : io,
233 : [](auto&& par_, auto&& pop_) {
234 : return Model{pop_, par_};
235 : },
236 : par, pop);
237 : }
238 : };
239 :
240 : } // namespace oseir
241 : } // namespace mio
242 :
243 : #endif // SEIR_MODEL_H
|