LCOV - code coverage report
Current view: top level - models/ode_seir - model.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 81 81 100.0 %
Date: 2024-11-18 12:45:26 Functions: 5 5 100.0 %

          Line data    Source code
       1             : /* 
       2             : * Copyright (C) 2020-2024 MEmilio
       3             : *
       4             : * Authors: Daniel Abele, Jan Kleinert, Martin J. Kuehn
       5             : *
       6             : * Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
       7             : *
       8             : * Licensed under the Apache License, Version 2.0 (the "License");
       9             : * you may not use this file except in compliance with the License.
      10             : * You may obtain a copy of the License at
      11             : *
      12             : *     http://www.apache.org/licenses/LICENSE-2.0
      13             : *
      14             : * Unless required by applicable law or agreed to in writing, software
      15             : * distributed under the License is distributed on an "AS IS" BASIS,
      16             : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      17             : * See the License for the specific language governing permissions and
      18             : * limitations under the License.
      19             : */
      20             : #ifndef SEIR_MODEL_H
      21             : #define SEIR_MODEL_H
      22             : 
      23             : #include "memilio/compartments/flow_model.h"
      24             : #include "memilio/config.h"
      25             : #include "memilio/epidemiology/age_group.h"
      26             : #include "memilio/epidemiology/populations.h"
      27             : #include "memilio/math/interpolation.h"
      28             : #include "memilio/utils/time_series.h"
      29             : #include "ode_seir/infection_state.h"
      30             : #include "ode_seir/parameters.h"
      31             : 
      32             : GCC_CLANG_DIAGNOSTIC(push)
      33             : GCC_CLANG_DIAGNOSTIC(ignored "-Wshadow")
      34             : #include <Eigen/Dense>
      35             : GCC_CLANG_DIAGNOSTIC(pop)
      36             : 
      37             : namespace mio
      38             : {
      39             : namespace oseir
      40             : {
      41             : 
      42             : /********************
      43             :  * define the model *
      44             :  ********************/
      45             : 
      46             : // clang-format off
      47             : using Flows = TypeList<Flow<InfectionState::Susceptible, InfectionState::Exposed>,
      48             :                        Flow<InfectionState::Exposed,     InfectionState::Infected>,
      49             :                        Flow<InfectionState::Infected,    InfectionState::Recovered>>;
      50             : // clang-format on
      51             : template <typename FP = ScalarType>
      52             : class Model
      53             :     : public FlowModel<FP, InfectionState, mio::Populations<FP, AgeGroup, InfectionState>, Parameters<FP>, Flows>
      54             : {
      55             :     using Base = FlowModel<FP, InfectionState, mio::Populations<FP, AgeGroup, InfectionState>, Parameters<FP>, Flows>;
      56             : 
      57             : public:
      58             :     using typename Base::ParameterSet;
      59             :     using typename Base::Populations;
      60             : 
      61             :     Model(const Populations& pop, const ParameterSet& params)
      62             :         : Base(pop, params)
      63             :     {
      64             :     }
      65             : 
      66         147 :     Model(int num_agegroups)
      67         147 :         : Base(Populations({AgeGroup(num_agegroups), InfectionState::Count}), ParameterSet(AgeGroup(num_agegroups)))
      68             :     {
      69         147 :     }
      70             : 
      71      386964 :     void get_flows(Eigen::Ref<const Vector<FP>> pop, Eigen::Ref<const Vector<FP>> y, FP t,
      72             :                    Eigen::Ref<Vector<FP>> flows) const override
      73             :     {
      74      386964 :         const Index<AgeGroup> age_groups = reduce_index<Index<AgeGroup>>(this->populations.size());
      75      386964 :         const auto& params               = this->parameters;
      76             : 
      77     1160892 :         for (auto i : make_index_range(age_groups)) {
      78      386964 :             const size_t Si = this->populations.get_flat_index({i, InfectionState::Susceptible});
      79      386964 :             const size_t Ei = this->populations.get_flat_index({i, InfectionState::Exposed});
      80      386964 :             const size_t Ii = this->populations.get_flat_index({i, InfectionState::Infected});
      81             : 
      82     1160892 :             for (auto j : make_index_range(age_groups)) {
      83      386964 :                 const size_t Sj = this->populations.get_flat_index({i, InfectionState::Susceptible});
      84      386964 :                 const size_t Ej = this->populations.get_flat_index({j, InfectionState::Exposed});
      85      386964 :                 const size_t Ij = this->populations.get_flat_index({j, InfectionState::Infected});
      86      386964 :                 const size_t Rj = this->populations.get_flat_index({j, InfectionState::Recovered});
      87             : 
      88      386964 :                 const ScalarType Nj    = pop[Sj] + pop[Ej] + pop[Ij] + pop[Rj];
      89      386964 :                 const ScalarType divNj = (Nj < Limits<ScalarType>::zero_tolerance()) ? 0.0 : 1.0 / Nj;
      90      386964 :                 const ScalarType coeffStoE =
      91      773928 :                     params.template get<ContactPatterns<FP>>().get_cont_freq_mat().get_matrix_at(t)(i.get(), j.get()) *
      92      386964 :                     params.template get<TransmissionProbabilityOnContact<FP>>()[i] * divNj;
      93             : 
      94      386964 :                 flows[Base::template get_flat_flow_index<InfectionState::Susceptible, InfectionState::Exposed>(i)] +=
      95      386964 :                     coeffStoE * y[Si] * pop[Ij];
      96             :             }
      97      386964 :             flows[Base::template get_flat_flow_index<InfectionState::Exposed, InfectionState::Infected>(i)] =
      98      386964 :                 (1.0 / params.template get<TimeExposed<FP>>()[i]) * y[Ei];
      99      386964 :             flows[Base::template get_flat_flow_index<InfectionState::Infected, InfectionState::Recovered>(i)] =
     100      386964 :                 (1.0 / params.template get<TimeInfected<FP>>()[i]) * y[Ii];
     101             :         }
     102      386964 :     }
     103             : 
     104             :     /**
     105             :     *@brief Computes the reproduction number at a given index time of the Model output obtained by the Simulation.
     106             :     *@param t_idx The index time at which the reproduction number is computed.
     107             :     *@param y The TimeSeries obtained from the Model Simulation.
     108             :     *@returns The computed reproduction number at the provided index time.
     109             :     */
     110         333 :     IOResult<ScalarType> get_reproduction_number(size_t t_idx, const mio::TimeSeries<ScalarType>& y)
     111             :     {
     112         333 :         if (!(t_idx < static_cast<size_t>(y.get_num_time_points()))) {
     113           9 :             return mio::failure(mio::StatusCode::OutOfRange, "t_idx is not a valid index for the TimeSeries");
     114             :         }
     115             : 
     116         324 :         auto const& params = this->parameters;
     117             : 
     118         324 :         const size_t num_groups                    = (size_t)params.get_num_groups();
     119         324 :         constexpr size_t num_infected_compartments = 2;
     120         324 :         const size_t total_infected_compartments   = num_infected_compartments * num_groups;
     121             : 
     122         324 :         ContactMatrixGroup const& contact_matrix = params.template get<ContactPatterns<ScalarType>>();
     123             : 
     124         324 :         Eigen::MatrixXd F = Eigen::MatrixXd::Zero(total_infected_compartments, total_infected_compartments);
     125         324 :         Eigen::MatrixXd V = Eigen::MatrixXd::Zero(total_infected_compartments, total_infected_compartments);
     126             : 
     127         648 :         for (auto i = AgeGroup(0); i < AgeGroup(num_groups); i++) {
     128         324 :             size_t Si = this->populations.get_flat_index({i, InfectionState::Susceptible});
     129         648 :             for (auto j = AgeGroup(0); j < AgeGroup(num_groups); j++) {
     130             : 
     131         324 :                 const ScalarType Nj    = this->populations.get_group_total(j);
     132         324 :                 const ScalarType divNj = (Nj < 1e-12) ? 0.0 : 1.0 / Nj;
     133             : 
     134         648 :                 ScalarType coeffStoE = contact_matrix.get_matrix_at(y.get_time(t_idx))(i.get(), j.get()) *
     135         324 :                                        params.template get<TransmissionProbabilityOnContact<ScalarType>>()[i] * divNj;
     136         324 :                 F((size_t)i, (size_t)j + num_groups) = coeffStoE * y.get_value(t_idx)[Si];
     137             :             }
     138             : 
     139         324 :             ScalarType T_Ei                      = params.template get<mio::oseir::TimeExposed<ScalarType>>()[i];
     140         324 :             ScalarType T_Ii                      = params.template get<mio::oseir::TimeInfected<ScalarType>>()[i];
     141         324 :             V((size_t)i, (size_t)i)              = 1.0 / T_Ei;
     142         324 :             V((size_t)i + num_groups, (size_t)i) = -1.0 / T_Ei;
     143         324 :             V((size_t)i + num_groups, (size_t)i + num_groups) = 1.0 / T_Ii;
     144             :         }
     145             : 
     146         324 :         V = V.inverse();
     147             : 
     148         324 :         Eigen::MatrixXd NextGenMatrix = Eigen::MatrixXd::Zero(total_infected_compartments, total_infected_compartments);
     149         324 :         NextGenMatrix                 = F * V;
     150             : 
     151             :         //Compute the largest eigenvalue in absolute value
     152         324 :         Eigen::ComplexEigenSolver<Eigen::MatrixXd> ces;
     153             : 
     154         324 :         ces.compute(NextGenMatrix);
     155         324 :         const Eigen::VectorXcd eigen_vals = ces.eigenvalues();
     156             : 
     157         324 :         Eigen::VectorXd eigen_vals_abs;
     158         324 :         eigen_vals_abs.resize(eigen_vals.size());
     159             : 
     160         972 :         for (int i = 0; i < eigen_vals.size(); i++) {
     161         648 :             eigen_vals_abs[i] = std::abs(eigen_vals[i]);
     162             :         }
     163         324 :         return mio::success(eigen_vals_abs.maxCoeff());
     164         324 :     }
     165             : 
     166             :     /**
     167             :     *@brief Computes the reproduction number for all time points of the Model output obtained by the Simulation.
     168             :     *@param y The TimeSeries obtained from the Model Simulation.
     169             :     *@returns vector containing all reproduction numbers
     170             :     */
     171          27 :     Eigen::VectorXd get_reproduction_numbers(const mio::TimeSeries<ScalarType>& y)
     172             :     {
     173          27 :         auto num_time_points = y.get_num_time_points();
     174          27 :         Eigen::VectorXd temp(num_time_points);
     175         216 :         for (size_t i = 0; i < static_cast<size_t>(num_time_points); i++) {
     176         189 :             temp[i] = get_reproduction_number(i, y).value();
     177             :         }
     178          54 :         return temp;
     179          27 :     }
     180             : 
     181             :     /**
     182             :     *@brief Computes the reproduction number at a given time point of the Model output obtained by the Simulation. If the particular time point is not inside the output, a linearly interpolated value is returned.
     183             :     *@param t_value The time point at which the reproduction number is computed.
     184             :     *@param y The TimeSeries obtained from the Model Simulation.
     185             :     *@returns The computed reproduction number at the provided time point, potentially using linear interpolation.
     186             :     */
     187          99 :     IOResult<ScalarType> get_reproduction_number(ScalarType t_value, const mio::TimeSeries<ScalarType>& y)
     188             :     {
     189          99 :         if (t_value < y.get_time(0) || t_value > y.get_last_time()) {
     190             :             return mio::failure(mio::StatusCode::OutOfRange,
     191          27 :                                 "Cannot interpolate reproduction number outside computed horizon of the TimeSeries");
     192             :         }
     193             : 
     194          72 :         if (t_value == y.get_time(0)) {
     195          18 :             return mio::success(get_reproduction_number((size_t)0, y).value());
     196             :         }
     197             : 
     198         108 :         auto times = std::vector<ScalarType>(y.get_times().begin(), y.get_times().end());
     199             : 
     200          54 :         auto time_late = std::distance(times.begin(), std::lower_bound(times.begin(), times.end(), t_value));
     201             : 
     202          54 :         ScalarType y1 = get_reproduction_number(static_cast<size_t>(time_late - 1), y).value();
     203          54 :         ScalarType y2 = get_reproduction_number(static_cast<size_t>(time_late), y).value();
     204             : 
     205          54 :         auto result = linear_interpolation(t_value, y.get_time(time_late - 1), y.get_time(time_late), y1, y2);
     206          54 :         return mio::success(static_cast<ScalarType>(result));
     207          54 :     }
     208             : 
     209             :     /**
     210             :      * serialize this. 
     211             :      * @see mio::serialize
     212             :      */
     213             :     template <class IOContext>
     214             :     void serialize(IOContext& io) const
     215             :     {
     216             :         auto obj = io.create_object("Model");
     217             :         obj.add_element("Parameters", this->parameters);
     218             :         obj.add_element("Populations", this->populations);
     219             :     }
     220             : 
     221             :     /**
     222             :      * deserialize an object of this class.
     223             :      * @see mio::deserialize
     224             :      */
     225             :     template <class IOContext>
     226             :     static IOResult<Model> deserialize(IOContext& io)
     227             :     {
     228             :         auto obj = io.expect_object("Model");
     229             :         auto par = obj.expect_element("Parameters", Tag<ParameterSet>{});
     230             :         auto pop = obj.expect_element("Populations", Tag<Populations>{});
     231             :         return apply(
     232             :             io,
     233             :             [](auto&& par_, auto&& pop_) {
     234             :                 return Model{pop_, par_};
     235             :             },
     236             :             par, pop);
     237             :     }
     238             : };
     239             : 
     240             : } // namespace oseir
     241             : } // namespace mio
     242             : 
     243             : #endif // SEIR_MODEL_H

Generated by: LCOV version 1.14