LCOV - code coverage report
Current view: top level - models/smm - simulation.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 72 72 100.0 %
Date: 2025-04-03 12:28:53 Functions: 10 19 52.6 %

          Line data    Source code
       1             : /* 
       2             : * Copyright (C) 2020-2025 German Aerospace Center (DLR-SC)
       3             : *
       4             : * Authors: René Schmieding, Julia Bicker
       5             : *
       6             : * Contact: Martin J. Kuehn <Martin.Kuehn@DLR.de>
       7             : *
       8             : * Licensed under the Apache License, Version 2.0 (the "License");
       9             : * you may not use this file except in compliance with the License.
      10             : * You may obtain a copy of the License at
      11             : *
      12             : *     http://www.apache.org/licenses/LICENSE-2.0
      13             : *
      14             : * Unless required by applicable law or agreed to in writing, software
      15             : * distributed under the License is distributed on an "AS IS" BASIS,
      16             : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      17             : * See the License for the specific language governing permissions and
      18             : * limitations under the License.
      19             : */
      20             : 
      21             : #ifndef MIO_SMM_SIMULATION_H
      22             : #define MIO_SMM_SIMULATION_H
      23             : 
#include "memilio/config.h"
#include "smm/model.h"
#include "smm/parameters.h"
#include "memilio/compartments/simulation.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <limits>
#include <memory>
#include <vector>
      28             : 
      29             : namespace mio
      30             : {
      31             : 
      32             : namespace smm
      33             : {
      34             : 
      35             : /**
      36             :  * @brief A specialized Simulation for mio::smm::Model.
      37             :  * @tparam regions The number of regions.
      38             :  * @tparam Status An infection state enum.
      39             :  */
      40             : template <size_t regions, class Status>
      41             : class Simulation
      42             : {
      43             : public:
      44             : public:
      45             :     using Model = smm::Model<regions, Status>;
      46             : 
      47             :     /**
      48             :      * @brief Set up the simulation for a Stochastic Metapopulation Model.
      49             :      * @param[in] model An instance of mio::smm::Model.
      50             :      * @param[in] t0 Start time.
      51             :      * @param[in] dt Initial Step size.
      52             :      */
      53         102 :     Simulation(Model const& model, ScalarType t0 = 0., ScalarType dt = 1.)
      54         102 :         : m_dt(dt)
      55         102 :         , m_model(std::make_unique<Model>(model))
      56         102 :         , m_result(t0, m_model->get_initial_values())
      57         102 :         , m_internal_time(adoption_rates().size() + transition_rates().size(), t0)
      58         102 :         , m_tp_next_event(adoption_rates().size() + transition_rates().size(), t0)
      59         102 :         , m_waiting_times(adoption_rates().size() + transition_rates().size(), 0)
      60         102 :         , m_current_rates(adoption_rates().size() + transition_rates().size(), 0)
      61             :     {
      62         102 :         assert(dt > 0);
      63         102 :         assert(m_waiting_times.size() > 0);
      64         109 :         assert(std::all_of(adoption_rates().begin(), adoption_rates().end(), [](auto&& r) {
      65             :             return static_cast<size_t>(r.region) < regions;
      66             :         }));
      67         204 :         assert(std::all_of(transition_rates().begin(), transition_rates().end(), [](auto&& r) {
      68             :             return static_cast<size_t>(r.from) < regions && static_cast<size_t>(r.to) < regions;
      69             :         }));
      70             :         // initialize (internal) next event times by random values
      71         211 :         for (size_t i = 0; i < m_tp_next_event.size(); i++) {
      72         109 :             m_tp_next_event[i] += mio::ExponentialDistribution<ScalarType>::get_instance()(m_model->get_rng(), 1.0);
      73             :         }
      74         102 :     }
      75             : 
      76             :     /**
      77             :      * @brief Advance simulation to tmax.
      78             :      * This function performs a Gillespie algorithm.
      79             :      * @param tmax Next stopping point of simulation.
      80             :      */
      81         102 :     Eigen::Ref<Eigen::VectorXd> advance(ScalarType tmax)
      82             :     {
      83         102 :         update_current_rates_and_waiting_times();
      84         102 :         size_t next_event       = determine_next_event(); // index of the next event
      85         102 :         ScalarType current_time = m_result.get_last_time();
      86             :         // set in the past to add a new time point immediately
      87         102 :         ScalarType last_result_time = current_time - m_dt;
      88             :         // iterate over time
      89       18339 :         while (current_time + m_waiting_times[next_event] < tmax) {
      90             :             // update time
      91       18237 :             current_time += m_waiting_times[next_event];
      92             :             // regularily save current state in m_results
      93       18237 :             if (current_time > last_result_time + m_dt) {
      94         102 :                 last_result_time = current_time;
      95         102 :                 m_result.add_time_point(current_time);
      96             :                 // copy from the previous last value
      97         102 :                 m_result.get_last_value() = m_result[m_result.get_num_time_points() - 2];
      98             :             }
      99             :             // decide event type by index and perform it
     100       18237 :             if (next_event < adoption_rates().size()) {
     101             :                 // perform adoption event
     102           1 :                 const auto& rate = adoption_rates()[next_event];
     103           1 :                 m_result.get_last_value()[m_model->populations.get_flat_index({rate.region, rate.from})] -= 1;
     104           1 :                 m_model->populations[{rate.region, rate.from}] -= 1;
     105           1 :                 m_result.get_last_value()[m_model->populations.get_flat_index({rate.region, rate.to})] += 1;
     106           1 :                 m_model->populations[{rate.region, rate.to}] += 1;
     107             :             }
     108             :             else {
     109             :                 // perform transition event
     110       18236 :                 const auto& rate = transition_rates()[next_event - adoption_rates().size()];
     111       18236 :                 m_result.get_last_value()[m_model->populations.get_flat_index({rate.from, rate.status})] -= 1;
     112       18236 :                 m_model->populations[{rate.from, rate.status}] -= 1;
     113       18236 :                 m_result.get_last_value()[m_model->populations.get_flat_index({rate.to, rate.status})] += 1;
     114       18236 :                 m_model->populations[{rate.to, rate.status}] += 1;
     115             :             }
     116             :             // update internal times
     117       36486 :             for (size_t i = 0; i < m_internal_time.size(); i++) {
     118       18249 :                 m_internal_time[i] += m_current_rates[i] * m_waiting_times[next_event];
     119             :             }
     120             :             // draw new "next event" time for the occured event
     121       18237 :             m_tp_next_event[next_event] +=
     122       36474 :                 mio::ExponentialDistribution<ScalarType>::get_instance()(m_model->get_rng(), 1.0);
     123             :             // precalculate next event
     124       18237 :             update_current_rates_and_waiting_times();
     125       18237 :             next_event = determine_next_event();
     126             :         }
     127             :         // copy last result, if no event occurs between last_result_time and tmax
     128         102 :         if (last_result_time < tmax) {
     129         102 :             m_result.add_time_point(tmax);
     130         102 :             m_result.get_last_value() = m_result[m_result.get_num_time_points() - 2];
     131             :         }
     132         102 :         return m_result.get_last_value();
     133             :     }
     134             : 
     135             :     /**
     136             :      * @brief Returns the final simulation result.
     137             :      * @return A TimeSeries to represent the final simulation result.
     138             :      */
     139          13 :     TimeSeries<ScalarType>& get_result()
     140             :     {
     141          13 :         return m_result;
     142             :     }
     143             :     const TimeSeries<ScalarType>& get_result() const
     144             :     {
     145             :         return m_result;
     146             :     }
     147             : 
     148             :     /**
     149             :      * @brief Returns the model used in the simulation.
     150             :      */
     151             :     const Model& get_model() const
     152             :     {
     153             :         return *m_model;
     154             :     }
     155         100 :     Model& get_model()
     156             :     {
     157         100 :         return *m_model;
     158             :     }
     159             : 
     160             : private:
     161             :     /**
     162             :      * @brief Returns the model's transition rates.
     163             :      */
     164       37187 :     inline constexpr const typename smm::TransitionRates<Status>::Type& transition_rates()
     165             :     {
     166       37187 :         return m_model->parameters.template get<smm::TransitionRates<Status>>();
     167             :     }
     168             : 
     169             :     /**
     170             :      * @brief Returns the model's adoption rates.
     171             :      */
     172       55425 :     inline constexpr const typename smm::AdoptionRates<Status>::Type& adoption_rates()
     173             :     {
     174       55425 :         return m_model->parameters.template get<smm::AdoptionRates<Status>>();
     175             :     }
     176             : 
     177             :     /**
     178             :      * @brief Calculate current values for m_current_rates and m_waiting_times.
     179             :      */
     180       18339 :     inline void update_current_rates_and_waiting_times()
     181             :     {
     182       18339 :         size_t i = 0; // shared index for iterating both rates
     183       18358 :         for (const auto& rate : adoption_rates()) {
     184          19 :             m_current_rates[i] = m_model->evaluate(rate, m_result.get_last_value());
     185          38 :             m_waiting_times[i] = (m_current_rates[i] > 0)
     186          19 :                                      ? (m_tp_next_event[i] - m_internal_time[i]) / m_current_rates[i]
     187          16 :                                      : std::numeric_limits<ScalarType>::max();
     188          19 :             i++;
     189             :         }
     190       36678 :         for (const auto& rate : transition_rates()) {
     191       18339 :             m_current_rates[i] = m_model->evaluate(rate, m_result.get_last_value());
     192       36678 :             m_waiting_times[i] = (m_current_rates[i] > 0)
     193       18339 :                                      ? (m_tp_next_event[i] - m_internal_time[i]) / m_current_rates[i]
     194           2 :                                      : std::numeric_limits<ScalarType>::max();
     195       18339 :             i++;
     196             :         }
     197       18339 :     }
     198             : 
     199             :     /**
     200             :      * @brief Get next event i.e. event with the smallest waiting time.
     201             :      */
     202       18339 :     inline size_t determine_next_event()
     203             :     {
     204       18339 :         return std::distance(m_waiting_times.begin(), std::min_element(m_waiting_times.begin(), m_waiting_times.end()));
     205             :     }
     206             : 
     207             :     ScalarType m_dt; ///< Initial step size
     208             :     std::unique_ptr<Model> m_model; ///< Pointer to the model used in the simulation.
     209             :     mio::TimeSeries<ScalarType> m_result; ///< Result time series.
     210             : 
     211             :     std::vector<ScalarType> m_internal_time; ///< Internal times of all poisson processes (aka T_k).
     212             :     std::vector<ScalarType> m_tp_next_event; ///< Internal time points of next event i after m_internal[i] (aka P_k).
     213             :     std::vector<ScalarType> m_waiting_times; ///< External times between m_internal_time and m_tp_next_event.
     214             :     std::vector<ScalarType>
     215             :         m_current_rates; ///< Current values of both types of rates i.e. adoption and transition rates.
     216             : };
     217             : 
     218             : } //namespace smm
     219             : } // namespace mio
     220             : 
     221             : #endif

Generated by: LCOV version 1.14