diff --git a/include/musica/micm/micm.hpp b/include/musica/micm/micm.hpp index f228d245f..d06cf7d7d 100644 --- a/include/musica/micm/micm.hpp +++ b/include/musica/micm/micm.hpp @@ -7,7 +7,6 @@ #include #include -#include #include #ifdef MUSICA_ENABLE_CUDA @@ -45,7 +44,7 @@ namespace class MusicaErrorCategory : public std::error_category { public: - const char *name() const noexcept override + const char* name() const noexcept override { return MUSICA_ERROR_CATEGORY; } @@ -74,6 +73,7 @@ inline std::error_code make_error_code(MusicaErrCode e) namespace musica { + class State; // forward declaration to break circular include /// @brief Types of MICM solver enum MICMSolver { @@ -106,7 +106,8 @@ namespace musica public: SolverVariant solver_variant_; - MICM(const Chemistry &chemistry, MICMSolver solver_type); + MICM(const Chemistry& chemistry, MICMSolver solver_type); + MICM(std::string config_path, MICMSolver solver_type); MICM() = default; ~MICM() { @@ -116,7 +117,7 @@ namespace musica // cuda must clean all of its runtime resources // Otherwise, we risk the CudaRosenbrock destructor running after // the cuda runtime has closed - std::visit([](auto &solver) { solver.reset(); }, solver_variant_); + std::visit([](auto& solver) { solver.reset(); }, solver_variant_); micm::cuda::CudaStreamSingleton::GetInstance().CleanUp(); #endif } @@ -124,19 +125,19 @@ namespace musica /// @brief Solve the system /// @param state Pointer to state object /// @param time_step Time [s] to advance the state by - micm::SolverResult Solve(musica::State *state, double time_step); + micm::SolverResult Solve(musica::State* state, double time_step); /// @brief Get a property for a chemical species /// @param species_name Name of the species /// @param property_name Name of the property /// @return Value of the property template - T GetSpeciesProperty(const std::string &species_name, const std::string &property_name) + T GetSpeciesProperty(const std::string& species_name, const std::string& property_name) { - micm::System system = std::visit([](auto &solver) -> micm::System { return solver->GetSystem(); }, solver_variant_); - for (const auto &phase_species : system.gas_phase_.phase_species_) + micm::System system = std::visit([](auto& solver) -> micm::System { return solver->GetSystem(); }, solver_variant_); + for (const auto& phase_species : system.gas_phase_.phase_species_) { - const auto &species = phase_species.species_; + const auto& species = phase_species.species_; if (species.name_ == species_name) { return species.GetProperty(property_name); @@ -149,7 +150,7 @@ namespace musica /// @return Maximum number of grid cells std::size_t GetMaximumNumberOfGridCells() { - return std::visit([](auto &solver) { return solver->MaximumNumberOfGridCells(); }, solver_variant_); + return std::visit([](auto& solver) { return solver->MaximumNumberOfGridCells(); }, solver_variant_); } }; diff --git a/include/musica/micm/micm_c_interface.hpp b/include/musica/micm/micm_c_interface.hpp index a48ca636b..ffa31c233 100644 --- a/include/musica/micm/micm_c_interface.hpp +++ b/include/musica/micm/micm_c_interface.hpp @@ -94,6 +94,10 @@ namespace musica size_t GetMaximumNumberOfGridCells(MICM *micm); bool _IsCudaAvailable(Error *error); + + /// @brief Get the MUSICA vector size + /// @return The MUSICA vector size + std::size_t GetVectorSize(musica::MICMSolver); #ifdef __cplusplus } #endif diff --git a/include/musica/micm/state.hpp b/include/musica/micm/state.hpp index d01db608e..82c0b974f 100644 --- a/include/musica/micm/state.hpp +++ b/include/musica/micm/state.hpp @@ -72,6 +72,24 @@ namespace musica /// @param concentrations Vector of concentrations void SetOrderedConcentrations(const std::vector& concentrations); + /// @brief Set the concentrations from a map of species name to concentration vectors + /// @param input a mapping of species name to concentrations per grid cell + /// @param solver_type The solver type to use for ordering + void SetConcentrations(const std::map>& input, musica::MICMSolver solver_type); + + /// @brief Get the concentrations as a map of species name to concentration vectors + /// @return Map of species name to concentration vectors + std::map> GetConcentrations(musica::MICMSolver solver_type) const; + + /// @brief Set the rate constants from a map of species name to rate constant vectors + /// @param input a mapping of species name to rate constants per grid cell + /// @param solver_type The solver type to use for ordering + void SetRateConstants(const std::map>& input, musica::MICMSolver solver_type); + + /// @brief Get the rate constants as a map of species name to rate constant vectors + /// @return Map of species name to rate constant vectors + std::map> GetRateConstants(musica::MICMSolver solver_type) const; + /// @brief Get the vector of rate constants /// @return Vector of doubles std::vector& GetOrderedRateParameters(); diff --git a/javascript/CMakeLists.txt b/javascript/CMakeLists.txt index 2862fa4d3..0ddbd1d3d 100644 --- a/javascript/CMakeLists.txt +++ b/javascript/CMakeLists.txt @@ -1,16 +1,12 @@ project(musica-wasm) -# Set C++ standard (required for MUSICA headers) -set(CMAKE_CXX_STANDARD 20) -set(CMAKE_CXX_STANDARD_REQUIRED ON) - # Building WASM module add_executable(musica-wasm src/musica_wasm.cpp - src/micm/state_wrapper.cpp - src/micm/micm_wrapper.cpp ) +target_compile_features(musica-wasm PUBLIC cxx_std_20) + # Link against MUSICA target_link_libraries(musica-wasm musica::musica) diff --git a/javascript/micm/index.js b/javascript/micm/index.js index 33042f2b0..5ae6348c6 100644 --- a/javascript/micm/index.js +++ b/javascript/micm/index.js @@ -3,6 +3,6 @@ export { MICM } from './micm.js'; export { State } from './state.js'; export { Conditions } from './conditions.js'; -export { SolverType } from './solver.js'; +export { SolverType, toWasmSolverType } from './solver.js'; export { SolverState, SolverStats, SolverResult } from './solver_result.js'; export { GAS_CONSTANT, AVOGADRO, BOLTZMANN } from './utils.js'; diff --git a/javascript/micm/micm.js b/javascript/micm/micm.js index 6c9f5b39d..0cd840d5b 100644 --- a/javascript/micm/micm.js +++ b/javascript/micm/micm.js @@ -1,5 +1,5 @@ import { State } from './state.js'; -import { SolverType } from './solver.js'; +import { SolverType, toWasmSolverType } from './solver.js'; import { SolverStats, SolverResult } from './solver_result.js'; import { getBackend } from '../backend.js'; @@ -18,6 +18,7 @@ export class MICM { try { const backend = getBackend(); + const wasmSolver = toWasmSolverType(solverType); // In Node.js with NODEFS mounted, configuration files are exposed under /host. // In browser environments, this prefix is invalid, so only add it when running under Node. let resolvedConfigPath = configPath; @@ -28,7 +29,7 @@ export class MICM { if (isNodeEnv && !configPath.startsWith('/host/')) { resolvedConfigPath = `/host/${configPath}`; } - const nativeMICM = backend.MICM.fromConfigPath(resolvedConfigPath, solverType); + const nativeMICM = backend.MICM.fromConfigPath(resolvedConfigPath, wasmSolver); return new MICM(nativeMICM, solverType); } catch (error) { throw new Error(`Failed to create MICM solver from config path: ${error.message}`); @@ -49,10 +50,11 @@ export class MICM { try { const backend = getBackend(); + const wasmSolver = toWasmSolverType(solverType); const mechanismJSON = mechanism.getJSON(); const jsonString = JSON.stringify(mechanismJSON); - const nativeMICM = backend.MICM.fromConfigString(jsonString, solverType); + const nativeMICM = backend.MICM.fromConfigString(jsonString, wasmSolver); return new MICM(nativeMICM, solverType); } catch (error) { throw new Error(`Failed to create MICM solver from mechanism: ${error.message}`); @@ -76,8 +78,7 @@ export class MICM { if (numberOfGridCells <= 0) { throw new RangeError('number_of_grid_cells must be greater than 0'); } - const nativeState = this._nativeMICM.createState(numberOfGridCells); - return new State(nativeState); + return new State(this._nativeMICM, numberOfGridCells, this._solverType); } solve(state, timeStep) { diff --git a/javascript/micm/solver.js b/javascript/micm/solver.js index d255d1dd2..c47915901 100644 --- a/javascript/micm/solver.js +++ b/javascript/micm/solver.js @@ -1,6 +1,51 @@ +import { getBackend } from '../backend.js'; + +/** + * Enum for solver types + * @readonly + * @enum {number} + */ export const SolverType = { rosenbrock: 1, // Vector-ordered Rosenbrock solver rosenbrock_standard_order: 2, // Standard-ordered Rosenbrock solver backward_euler: 3, // Vector-ordered BackwardEuler solver backward_euler_standard_order: 4, // Standard-ordered BackwardEuler solver -}; \ No newline at end of file +}; + +/** + * Converts a solver type to the WASM SolverType enum. + * Accepts either a number (1-5) or a WASM enum value. + * Throws if input is invalid. + * + * @param {number|any} solverType - Numeric solver type or WASM enum + * @returns {any} WASM SolverType enum value + */ +export function toWasmSolverType(solverType) { + if (solverType === undefined || solverType === null) { + throw new TypeError('solverType is required'); + } + + const backend = getBackend(); + const WASMEnum = backend.SolverType; + + if (typeof solverType === 'number') { + switch (solverType) { + case 1: return WASMEnum.Rosenbrock; + case 2: return WASMEnum.RosenbrockStandardOrder; + case 3: return WASMEnum.BackwardEuler; + case 4: return WASMEnum.BackwardEulerStandardOrder; + case 5: return WASMEnum.CudaRosenbrock; + default: + throw new RangeError(`Invalid numeric solverType: ${solverType}`); + } + } + + // If it’s already one of the WASM enum values, accept it + for (const key of Object.keys(WASMEnum)) { + if (solverType === WASMEnum[key]) { + return solverType; + } + } + + throw new TypeError('solverType must be a valid WASM MICMSolver enum or number'); +} \ No newline at end of file diff --git a/javascript/micm/state.js b/javascript/micm/state.js index 890f3b069..b63512155 100644 --- a/javascript/micm/state.js +++ b/javascript/micm/state.js @@ -1,21 +1,34 @@ import { isScalarNumber } from './utils.js'; +import { getBackend } from '../backend.js'; +import { toWasmSolverType } from './solver.js'; +import { GAS_CONSTANT } from './utils.js'; export class State { - constructor(nativeState) { - this._nativeState = nativeState; + constructor(nativeMICM, numberOfGridCells, solverType) { + if (!nativeMICM) { + throw new TypeError('nativeMICM is required'); + } + if (numberOfGridCells < 1) { + throw new RangeError('number_of_grid_cells must be greater than 0'); + } + + const backend = getBackend(); + this._nativeState = backend.create_state(nativeMICM, numberOfGridCells); + + this._numberOfGridCells = numberOfGridCells; + this._solverType = toWasmSolverType(solverType); } setConcentrations(concentrations) { - // Convert to format expected by WASM const formatted = {}; for (const [name, value] of Object.entries(concentrations)) { formatted[name] = isScalarNumber(value) ? [value] : value; } - this._nativeState.setConcentrations(formatted); + this._nativeState.set_concentrations(formatted, this._solverType); } getConcentrations() { - return this._nativeState.getConcentrations(); + return this._nativeState.get_concentrations(this._solverType); } setUserDefinedRateParameters(params) { @@ -23,60 +36,53 @@ export class State { for (const [name, value] of Object.entries(params)) { formatted[name] = isScalarNumber(value) ? [value] : value; } - this._nativeState.setUserDefinedRateParameters(formatted); + this._nativeState.set_user_defined_constants(formatted, this._solverType); } getUserDefinedRateParameters() { - return this._nativeState.getUserDefinedRateParameters(); + return this._nativeState.get_user_defined_constants(this._solverType); } - setConditions({ - temperatures = null, - pressures = null, - air_densities = null, - } = {}) { - const cond = {}; + setConditions({ temperatures = null, pressures = null, airDensities = null } = {}) { + const backend = getBackend(); + const vec = new backend.VectorConditions(); - if (temperatures !== null) { - cond.temperatures = isScalarNumber(temperatures) - ? [temperatures] - : temperatures; - } - if (pressures !== null) { - cond.pressures = isScalarNumber(pressures) - ? [pressures] - : pressures; - } - if (air_densities !== null) { - cond.air_densities = isScalarNumber(air_densities) - ? [air_densities] - : air_densities; - } + const expand = (param) => { + if (param === null || param === undefined) { + return Array(this._numberOfGridCells).fill(null); + } else if (!Array.isArray(param)) { + if (this._numberOfGridCells > 1) { + throw new Error("Scalar input requires a single grid cell"); + } + return [param]; + } else if (param.length !== this._numberOfGridCells) { + throw new Error(`Array input must have length ${this._numberOfGridCells}`); + } + return param; + }; - this._nativeState.setConditions(cond); - } + const temps = expand(temperatures); + const pres = expand(pressures); + const dens = expand(airDensities); - getConditions() { - return this._nativeState.getConditions(); - } - - getSpeciesOrdering() { - return this._nativeState.getSpeciesOrdering(); - } + for (let i = 0; i < this._numberOfGridCells; i++) { + const T = temps[i] !== null ? temps[i] : NaN; + const P = pres[i] !== null ? pres[i] : NaN; + let rho = dens[i] !== null ? dens[i] : (!Number.isNaN(T) && !Number.isNaN(P) ? P / (GAS_CONSTANT * T) : NaN); + - getUserDefinedRateParametersOrdering() { - return this._nativeState.getUserDefinedRateParametersOrdering(); - } + const cond = new backend.Condition(T, P, rho); + vec.push_back(cond); + } - getNumberOfGridCells() { - return this._nativeState.getNumberOfGridCells(); + this._nativeState.set_conditions(vec); } - concentrationStrides() { - return this._nativeState.concentrationStrides(); + getConditions() { + return this._nativeState.get_conditions(); } - userDefinedRateParameterStrides() { - return this._nativeState.userDefinedRateParameterStrides(); + getNumberOfGridCells() { + return this._numberOfGridCells; } -} \ No newline at end of file +} diff --git a/javascript/src/micm/micm_wrapper.cpp b/javascript/src/micm/micm_wrapper.cpp deleted file mode 100644 index 9a8ce8d14..000000000 --- a/javascript/src/micm/micm_wrapper.cpp +++ /dev/null @@ -1,112 +0,0 @@ -#include "micm_wrapper.h" - -#include "state_wrapper.h" - -#include -#include -#include -#include - -// Include MUSICA headers for real functionality -#include -#include -#include -#include -#include - -#include - -namespace musica_addon -{ - - // ============================================================================ - // MICMWrapper Implementation - // ============================================================================ - - MICMWrapper::MICMWrapper(musica::MICM* micm, int solver_type) - : micm_(micm), - solver_type_(solver_type) - { - } - - std::unique_ptr MICMWrapper::FromConfigPath(const std::string& config_path, int solver_type) - { - musica::Error error; - musica::MICM* micm = musica::CreateMicm(config_path.c_str(), static_cast(solver_type), &error); - - if (!musica::IsSuccess(error)) - { - std::string error_msg = "Failed to create MICM solver: "; - if (error.message_.value_ != nullptr) - { - error_msg += error.message_.value_; - musica::DeleteString(&error.message_); - } - musica::DeleteError(&error); - throw std::runtime_error(error_msg); - } - musica::DeleteError(&error); - return std::unique_ptr(new MICMWrapper(micm, solver_type)); - } - - std::unique_ptr MICMWrapper::FromConfigString(const std::string& config_string, int solver_type) - { - musica::Error error; - musica::MICM* micm = musica::CreateMicmFromConfigString(config_string.c_str(), static_cast(solver_type), &error); - - if (!musica::IsSuccess(error)) - { - std::string error_msg = "Failed to create MICM solver from config string: "; - if (error.message_.value_ != nullptr) - { - error_msg += error.message_.value_; - musica::DeleteString(&error.message_); - } - musica::DeleteError(&error); - throw std::runtime_error(error_msg); - } - musica::DeleteError(&error); - return std::unique_ptr(new MICMWrapper(micm, solver_type)); - } - - MICMWrapper::~MICMWrapper() - { - if (micm_ != nullptr) - { - musica::Error error; - musica::DeleteMicm(micm_, &error); - musica::DeleteError(&error); - } - } - - musica::State* MICMWrapper::CreateState(size_t number_of_grid_cells) - { - musica::Error error; - musica::State* state = musica::CreateMicmState(micm_, number_of_grid_cells, &error); - - if (!musica::IsSuccess(error)) - { - std::string error_msg = "Failed to create state: "; - if (error.message_.value_ != nullptr) - { - error_msg += error.message_.value_; - musica::DeleteString(&error.message_); - } - musica::DeleteError(&error); - throw std::runtime_error(error_msg); - } - musica::DeleteError(&error); - return state; - } - - micm::SolverResult MICMWrapper::Solve(musica::State* state, double time_step) - { - return micm_->Solve(state, time_step); - } - - int MICMWrapper::GetSolverType() const - { - return solver_type_; - } - -} // namespace musica_addon diff --git a/javascript/src/micm/micm_wrapper.h b/javascript/src/micm/micm_wrapper.h deleted file mode 100644 index 0405a3826..000000000 --- a/javascript/src/micm/micm_wrapper.h +++ /dev/null @@ -1,43 +0,0 @@ -#pragma once - -#include "state_wrapper.h" - -#include - -#include - -#include -#include -#include -#include - -// Forward declarations of MUSICA types -namespace musica -{ - class State; -} - -namespace musica_addon -{ - - /// @brief C++ wrapper for MICM solver - class MICMWrapper - { - public: - ~MICMWrapper(); - - // Static factory methods - static std::unique_ptr FromConfigPath(const std::string& config_path, int solver_type); - static std::unique_ptr FromConfigString(const std::string& config_string, int solver_type); - - musica::State* CreateState(size_t number_of_grid_cells); - micm::SolverResult Solve(musica::State* state, double time_step); - int GetSolverType() const; - - private: - MICMWrapper(musica::MICM* micm, int solver_type); - musica::MICM* micm_; - int solver_type_; - }; - -} // namespace musica_addon \ No newline at end of file diff --git a/javascript/src/micm/state_wrapper.cpp b/javascript/src/micm/state_wrapper.cpp deleted file mode 100644 index 586f299f7..000000000 --- a/javascript/src/micm/state_wrapper.cpp +++ /dev/null @@ -1,388 +0,0 @@ -#include "state_wrapper.h" - -#include "micm_wrapper.h" - -#include -#include -#include -#include -#include - -// Include MUSICA headers for real functionality -#include -#include -#include -#include -#include -#include - -#include - -namespace musica_addon -{ - - // ============================================================================ - // StateWrapper Implementation - // ============================================================================ - - void StateDeleter::operator()(musica::State* state) const - { - if (state != nullptr) - { - musica::Error error; - musica::DeleteState(state, &error); - musica::DeleteError(&error); - } - } - - StateWrapper::StateWrapper(musica::State* state) - : state_(state, StateDeleter()) - { - } - - void StateWrapper::SetConcentrations(const std::map>& concentrations) - { - musica::Error error; - - // Get species ordering - musica::Mappings species_ordering; - musica::GetSpeciesOrdering(state_.get(), &species_ordering, &error); - if (!musica::IsSuccess(error)) - { - musica::DeleteError(&error); - throw std::runtime_error("Failed to get species ordering"); - } - - // Get concentration pointer and strides - size_t array_size; - double* conc_ptr = musica::GetOrderedConcentrationsPointer(state_.get(), &array_size, &error); - if (!musica::IsSuccess(error)) - { - musica::DeleteMappings(&species_ordering); - musica::DeleteError(&error); - throw std::runtime_error("Failed to get concentrations pointer"); - } - - size_t cell_stride, species_stride; - musica::GetConcentrationsStrides(state_.get(), &error, &cell_stride, &species_stride); - - size_t num_cells = musica::GetNumberOfGridCells(state_.get(), &error); - - // Set concentrations - for (size_t i = 0; i < species_ordering.size_; ++i) - { - std::string species_name = species_ordering.mappings_[i].name_.value_; - size_t species_idx = species_ordering.mappings_[i].index_; - - auto it = concentrations.find(species_name); - if (it != concentrations.end()) - { - const auto& values = it->second; - for (size_t cell = 0; cell < num_cells && cell < values.size(); ++cell) - { - conc_ptr[species_idx * species_stride + cell * cell_stride] = values[cell]; - } - } - } - - musica::DeleteMappings(&species_ordering); - musica::DeleteError(&error); - } - - std::map> StateWrapper::GetConcentrations() - { - musica::Error error; - std::map> result; - - // Get species ordering - musica::Mappings species_ordering; - musica::GetSpeciesOrdering(state_.get(), &species_ordering, &error); - if (!musica::IsSuccess(error)) - { - musica::DeleteError(&error); - throw std::runtime_error("Failed to get species ordering"); - } - - // Get concentration pointer and strides - size_t array_size; - double* conc_ptr = musica::GetOrderedConcentrationsPointer(state_.get(), &array_size, &error); - if (!musica::IsSuccess(error)) - { - musica::DeleteMappings(&species_ordering); - musica::DeleteError(&error); - throw std::runtime_error("Failed to get concentrations pointer"); - } - - size_t cell_stride, species_stride; - musica::GetConcentrationsStrides(state_.get(), &error, &cell_stride, &species_stride); - - size_t num_cells = musica::GetNumberOfGridCells(state_.get(), &error); - - // Get concentrations - for (size_t i = 0; i < species_ordering.size_; ++i) - { - std::string species_name = species_ordering.mappings_[i].name_.value_; - size_t species_idx = species_ordering.mappings_[i].index_; - - std::vector values(num_cells); - for (size_t cell = 0; cell < num_cells; ++cell) - { - values[cell] = conc_ptr[species_idx * species_stride + cell * cell_stride]; - } - result[species_name] = values; - } - - musica::DeleteMappings(&species_ordering); - musica::DeleteError(&error); - return result; - } - - void StateWrapper::SetUserDefinedRateParameters(const std::map>& params) - { - musica::Error error; - - // Get rate parameters ordering - musica::Mappings params_ordering; - musica::GetUserDefinedRateParametersOrdering(state_.get(), ¶ms_ordering, &error); - if (!musica::IsSuccess(error)) - { - musica::DeleteError(&error); - throw std::runtime_error("Failed to get user-defined rate parameters ordering"); - } - - // Get rate parameters pointer and strides - size_t array_size; - double* params_ptr = musica::GetOrderedRateParametersPointer(state_.get(), &array_size, &error); - if (!musica::IsSuccess(error)) - { - musica::DeleteMappings(¶ms_ordering); - musica::DeleteError(&error); - throw std::runtime_error("Failed to get rate parameters pointer"); - } - - size_t cell_stride, param_stride; - musica::GetUserDefinedRateParametersStrides(state_.get(), &error, &cell_stride, ¶m_stride); - - size_t num_cells = musica::GetNumberOfGridCells(state_.get(), &error); - - // Set parameters - for (size_t i = 0; i < params_ordering.size_; ++i) - { - std::string param_name = params_ordering.mappings_[i].name_.value_; - size_t param_idx = params_ordering.mappings_[i].index_; - - auto it = params.find(param_name); - if (it != params.end()) - { - const auto& values = it->second; - for (size_t cell = 0; cell < num_cells && cell < values.size(); ++cell) - { - params_ptr[param_idx * param_stride + cell * cell_stride] = values[cell]; - } - } - } - - musica::DeleteMappings(¶ms_ordering); - musica::DeleteError(&error); - } - - std::map> StateWrapper::GetUserDefinedRateParameters() - { - musica::Error error; - std::map> result; - - // Get rate parameters ordering - musica::Mappings params_ordering; - musica::GetUserDefinedRateParametersOrdering(state_.get(), ¶ms_ordering, &error); - if (!musica::IsSuccess(error)) - { - musica::DeleteError(&error); - throw std::runtime_error("Failed to get user-defined rate parameters ordering"); - } - - // Get rate parameters pointer and strides - size_t array_size; - double* params_ptr = musica::GetOrderedRateParametersPointer(state_.get(), &array_size, &error); - if (!musica::IsSuccess(error)) - { - musica::DeleteMappings(¶ms_ordering); - musica::DeleteError(&error); - throw std::runtime_error("Failed to get rate parameters pointer"); - } - - size_t cell_stride, param_stride; - musica::GetUserDefinedRateParametersStrides(state_.get(), &error, &cell_stride, ¶m_stride); - - size_t num_cells = musica::GetNumberOfGridCells(state_.get(), &error); - - // Get parameters - for (size_t i = 0; i < params_ordering.size_; ++i) - { - std::string param_name = params_ordering.mappings_[i].name_.value_; - size_t param_idx = params_ordering.mappings_[i].index_; - - std::vector values(num_cells); - for (size_t cell = 0; cell < num_cells; ++cell) - { - values[cell] = params_ptr[param_idx * param_stride + cell * cell_stride]; - } - result[param_name] = values; - } - - musica::DeleteMappings(¶ms_ordering); - musica::DeleteError(&error); - return result; - } - - void StateWrapper::SetConditions( - const std::vector* temperatures, - const std::vector* pressures, - const std::vector* air_densities) - { - musica::Error error; - size_t array_size; - micm::Conditions* conditions = musica::GetConditionsPointer(state_.get(), &array_size, &error); - - if (!musica::IsSuccess(error)) - { - musica::DeleteError(&error); - throw std::runtime_error("Failed to get conditions pointer"); - } - - size_t num_cells = musica::GetNumberOfGridCells(state_.get(), &error); - - for (size_t i = 0; i < num_cells; ++i) - { - if (temperatures && i < temperatures->size()) - { - conditions[i].temperature_ = (*temperatures)[i]; - } - if (pressures && i < pressures->size()) - { - conditions[i].pressure_ = (*pressures)[i]; - } - if (air_densities && i < air_densities->size()) - { - conditions[i].air_density_ = (*air_densities)[i]; - } - else if (temperatures && pressures && i < temperatures->size() && i < pressures->size()) - { - // Calculate air density from ideal gas law if not provided - constexpr double GAS_CONSTANT = 8.31446261815324; // J K^-1 mol^-1 - conditions[i].air_density_ = (*pressures)[i] / (GAS_CONSTANT * (*temperatures)[i]); - } - } - - musica::DeleteError(&error); - } - - std::map> StateWrapper::GetConditions() - { - musica::Error error; - std::map> result; - - size_t array_size; - micm::Conditions* conditions = musica::GetConditionsPointer(state_.get(), &array_size, &error); - - if (!musica::IsSuccess(error)) - { - musica::DeleteError(&error); - throw std::runtime_error("Failed to get conditions pointer"); - } - - size_t num_cells = musica::GetNumberOfGridCells(state_.get(), &error); - - result["temperature"] = std::vector(num_cells); - result["pressure"] = std::vector(num_cells); - result["air_density"] = std::vector(num_cells); - - for (size_t i = 0; i < num_cells; ++i) - { - result["temperature"][i] = conditions[i].temperature_; - result["pressure"][i] = conditions[i].pressure_; - result["air_density"][i] = conditions[i].air_density_; - } - - musica::DeleteError(&error); - return result; - } - - std::map StateWrapper::GetSpeciesOrdering() - { - musica::Error error; - std::map result; - - musica::Mappings species_ordering; - musica::GetSpeciesOrdering(state_.get(), &species_ordering, &error); - - if (!musica::IsSuccess(error)) - { - musica::DeleteError(&error); - throw std::runtime_error("Failed to get species ordering"); - } - - for (size_t i = 0; i < species_ordering.size_; ++i) - { - result[species_ordering.mappings_[i].name_.value_] = species_ordering.mappings_[i].index_; - } - - musica::DeleteMappings(&species_ordering); - musica::DeleteError(&error); - return result; - } - - std::map StateWrapper::GetUserDefinedRateParametersOrdering() - { - musica::Error error; - std::map result; - - musica::Mappings params_ordering; - musica::GetUserDefinedRateParametersOrdering(state_.get(), ¶ms_ordering, &error); - - if (!musica::IsSuccess(error)) - { - musica::DeleteError(&error); - throw std::runtime_error("Failed to get user-defined rate parameters ordering"); - } - - for (size_t i = 0; i < params_ordering.size_; ++i) - { - result[params_ordering.mappings_[i].name_.value_] = params_ordering.mappings_[i].index_; - } - - musica::DeleteMappings(¶ms_ordering); - musica::DeleteError(&error); - return result; - } - - void StateWrapper::GetConcentrationStrides(size_t& cell_stride, size_t& species_stride) - { - musica::Error error; - musica::GetConcentrationsStrides(state_.get(), &error, &cell_stride, &species_stride); - musica::DeleteError(&error); - } - - void StateWrapper::GetUserDefinedRateParameterStrides(size_t& cell_stride, size_t& param_stride) - { - musica::Error error; - musica::GetUserDefinedRateParametersStrides(state_.get(), &error, &cell_stride, ¶m_stride); - musica::DeleteError(&error); - } - - size_t StateWrapper::GetNumberOfGridCells() - { - musica::Error error; - size_t num_cells = musica::GetNumberOfGridCells(state_.get(), &error); - musica::DeleteError(&error); - return num_cells; - } - - double* StateWrapper::GetConcentrationsPointer(size_t& array_size) - { - musica::Error error; - double* ptr = musica::GetOrderedConcentrationsPointer(state_.get(), &array_size, &error); - musica::DeleteError(&error); - return ptr; - } - -} // namespace musica_addon \ No newline at end of file diff --git a/javascript/src/micm/state_wrapper.h b/javascript/src/micm/state_wrapper.h deleted file mode 100644 index 35022b7a5..000000000 --- a/javascript/src/micm/state_wrapper.h +++ /dev/null @@ -1,58 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -// Forward declarations of MUSICA types -namespace musica -{ - class State; -} - -namespace musica_addon -{ - - // Custom deleter for musica::State - struct StateDeleter - { - void operator()(musica::State* state) const; - }; - - /// @brief C++ wrapper for MICM state - class StateWrapper - { - public: - StateWrapper(musica::State* state); - ~StateWrapper() = default; - - void SetConcentrations(const std::map>& concentrations); - std::map> GetConcentrations(); - void SetUserDefinedRateParameters(const std::map>& params); - std::map> GetUserDefinedRateParameters(); - void SetConditions( - const std::vector* temperatures, - const std::vector* pressures, - const std::vector* air_densities); - std::map> GetConditions(); - - std::map GetSpeciesOrdering(); - std::map GetUserDefinedRateParametersOrdering(); - - void GetConcentrationStrides(size_t& cell_stride, size_t& species_stride); - void GetUserDefinedRateParameterStrides(size_t& cell_stride, size_t& param_stride); - - size_t GetNumberOfGridCells(); - - double* GetConcentrationsPointer(size_t& array_size); - musica::State* GetState() const - { - return state_.get(); - } - - private: - std::unique_ptr state_; - }; - -} // namespace musica_addon diff --git a/javascript/src/musica_wasm.cpp b/javascript/src/musica_wasm.cpp index dd124dbce..433f4be9b 100644 --- a/javascript/src/musica_wasm.cpp +++ b/javascript/src/musica_wasm.cpp @@ -3,9 +3,10 @@ // // WASM bindings for MUSICA using Emscripten -#include "micm/micm_wrapper.h" -#include "micm/state_wrapper.h" - +#include +#include +#include +#include #include #include @@ -18,301 +19,204 @@ #include using namespace emscripten; -using namespace musica_addon; - -// Wrapper functions to return std::string instead of const char* -std::string GetVersion() -{ - return std::string(musica::GetMusicaVersion()); -} - -std::string GetMicmVersion() -{ - return std::string(micm::GetMicmVersion()); -} - -// Helper function to convert JavaScript array to C++ vector -template -std::vector jsArrayToVector(const val& jsArray) -{ - std::vector vec; - unsigned int length = jsArray["length"].as(); - vec.reserve(length); - for (unsigned int i = 0; i < length; ++i) - { - vec.push_back(jsArray[i].as()); - } - return vec; -} - -// ============================================================================ -// StateWrapper bindings -// ============================================================================ - -class StateWrapperWASM -{ - public: - // Constructor - takes ownership of the state pointer via StateWrapper - explicit StateWrapperWASM(musica::State* state) - : wrapper_(std::make_unique(state)) - { - } - - // Move constructor and assignment - StateWrapperWASM(StateWrapperWASM&& other) = default; - StateWrapperWASM& operator=(StateWrapperWASM&& other) = default; - - // Delete copy constructor and assignment - StateWrapperWASM(const StateWrapperWASM&) = delete; - StateWrapperWASM& operator=(const StateWrapperWASM&) = delete; - - void setConcentrations(const val& concentrations) - { - std::map> conc_map; - // Convert JavaScript object to C++ map - auto keys = val::global("Object").call("keys", concentrations); - unsigned int length = keys["length"].as(); - for (unsigned int i = 0; i < length; ++i) - { - std::string key = keys[i].as(); - val value = concentrations[key]; - std::vector vec = jsArrayToVector(value); - conc_map[key] = vec; - } - wrapper_->SetConcentrations(conc_map); - } - - val getConcentrations() - { - auto conc_map = wrapper_->GetConcentrations(); - val result = val::object(); - for (const auto& pair : conc_map) - { - result.set(pair.first, val::array(pair.second.begin(), pair.second.end())); - } - return result; - } - - void setUserDefinedRateParameters(const val& params) - { - std::map> param_map; - auto keys = val::global("Object").call("keys", params); - unsigned int length = keys["length"].as(); - for (unsigned int i = 0; i < length; ++i) - { - std::string key = keys[i].as(); - val value = params[key]; - std::vector vec = jsArrayToVector(value); - param_map[key] = vec; - } - wrapper_->SetUserDefinedRateParameters(param_map); - } - - val getUserDefinedRateParameters() - { - auto param_map = wrapper_->GetUserDefinedRateParameters(); - val result = val::object(); - for (const auto& pair : param_map) - { - result.set(pair.first, val::array(pair.second.begin(), pair.second.end())); - } - return result; - } - - void setConditions(const val& conditions) - { - const std::vector* temperatures = nullptr; - const std::vector* pressures = nullptr; - const std::vector* air_densities = nullptr; - - std::vector temp_vec, press_vec, air_vec; - - if (conditions.hasOwnProperty("temperatures")) - { - temp_vec = jsArrayToVector(conditions["temperatures"]); - temperatures = &temp_vec; - } - if (conditions.hasOwnProperty("pressures")) - { - press_vec = jsArrayToVector(conditions["pressures"]); - pressures = &press_vec; - } - if (conditions.hasOwnProperty("air_densities")) - { - air_vec = jsArrayToVector(conditions["air_densities"]); - air_densities = &air_vec; - } - - wrapper_->SetConditions(temperatures, pressures, air_densities); - } - - val getConditions() - { - auto cond_map = wrapper_->GetConditions(); - val result = val::object(); - for (const auto& pair : cond_map) - { - result.set(pair.first, val::array(pair.second.begin(), pair.second.end())); - } - return result; - } - - val getSpeciesOrdering() - { - auto ordering = wrapper_->GetSpeciesOrdering(); - val result = val::object(); - for (const auto& pair : ordering) - { - result.set(pair.first, val(pair.second)); - } - return result; - } - - val getUserDefinedRateParametersOrdering() - { - auto ordering = wrapper_->GetUserDefinedRateParametersOrdering(); - val result = val::object(); - for (const auto& pair : ordering) - { - result.set(pair.first, val(pair.second)); - } - return result; - } - - size_t getNumberOfGridCells() - { - return wrapper_->GetNumberOfGridCells(); - } - - val concentrationStrides() - { - size_t cell_stride, species_stride; - wrapper_->GetConcentrationStrides(cell_stride, species_stride); - val result = val::object(); - result.set("cell_stride", val(cell_stride)); - result.set("species_stride", val(species_stride)); - return result; - } - - val userDefinedRateParameterStrides() - { - size_t cell_stride, param_stride; - wrapper_->GetUserDefinedRateParameterStrides(cell_stride, param_stride); - val result = val::object(); - result.set("cell_stride", val(cell_stride)); - result.set("param_stride", val(param_stride)); - return result; - } - - // Make wrapper accessible for MICMWrapperWASM - StateWrapper& getWrapper() - { - return *wrapper_; - } - - private: - std::unique_ptr wrapper_; -}; - -// ============================================================================ -// MICMWrapper bindings -// ============================================================================ - -class MICMWrapperWASM -{ - public: - // Constructor - explicit MICMWrapperWASM(std::unique_ptr wrapper) - : wrapper_(std::move(wrapper)) - { - } - - // Move constructor and assignment - MICMWrapperWASM(MICMWrapperWASM&& other) = default; - MICMWrapperWASM& operator=(MICMWrapperWASM&& other) = default; - - // Delete copy constructor and assignment - MICMWrapperWASM(const MICMWrapperWASM&) = delete; - MICMWrapperWASM& operator=(const MICMWrapperWASM&) = delete; - - static std::shared_ptr fromConfigPath(const std::string& config_path, int solver_type) - { - auto wrapper = MICMWrapper::FromConfigPath(config_path, solver_type); - return std::make_shared(std::move(wrapper)); - } - - static std::shared_ptr fromConfigString(const std::string& config_string, int solver_type) - { - auto wrapper = MICMWrapper::FromConfigString(config_string, solver_type); - return std::make_shared(std::move(wrapper)); - } - - std::shared_ptr createState(size_t number_of_grid_cells) - { - musica::State* state = wrapper_->CreateState(number_of_grid_cells); - return std::make_shared(state); - } - - val solve(StateWrapperWASM& state, double time_step) - { - // Note: We need to extract the raw state pointer to pass to the underlying - // MICM solver. The StateWrapperWASM maintains ownership of the state. - // This is safe as long as the state outlives this solve call. - musica::State* raw_state = state.getWrapper().GetState(); - auto result = wrapper_->Solve(raw_state, time_step); - - // Convert SolverResult to JavaScript object - val js_result = val::object(); - js_result.set("state", val(static_cast(result.state_))); - - val stats = val::object(); - stats.set("function_calls", val(result.stats_.function_calls_)); - stats.set("jacobian_updates", val(result.stats_.jacobian_updates_)); - stats.set("number_of_steps", val(result.stats_.number_of_steps_)); - stats.set("accepted", val(result.stats_.accepted_)); - stats.set("rejected", val(result.stats_.rejected_)); - stats.set("decompositions", val(result.stats_.decompositions_)); - stats.set("solves", val(result.stats_.solves_)); - stats.set("final_time", val(result.stats_.final_time_)); - - js_result.set("stats", stats); - - return js_result; - } - - int solverType() const - { - return wrapper_->GetSolverType(); - } - - private: - std::unique_ptr wrapper_; -}; EMSCRIPTEN_BINDINGS(musica_module) { - function("getVersion", &GetVersion); - function("getMicmVersion", &GetMicmVersion); - - class_("State") - .smart_ptr>("StatePtr") - .function("setConcentrations", &StateWrapperWASM::setConcentrations) - .function("getConcentrations", &StateWrapperWASM::getConcentrations) - .function("setUserDefinedRateParameters", &StateWrapperWASM::setUserDefinedRateParameters) - .function("getUserDefinedRateParameters", &StateWrapperWASM::getUserDefinedRateParameters) - .function("setConditions", &StateWrapperWASM::setConditions) - .function("getConditions", &StateWrapperWASM::getConditions) - .function("getSpeciesOrdering", &StateWrapperWASM::getSpeciesOrdering) - .function("getUserDefinedRateParametersOrdering", &StateWrapperWASM::getUserDefinedRateParametersOrdering) - .function("getNumberOfGridCells", &StateWrapperWASM::getNumberOfGridCells) - .function("concentrationStrides", &StateWrapperWASM::concentrationStrides) - .function("userDefinedRateParameterStrides", &StateWrapperWASM::userDefinedRateParameterStrides); - - class_("MICM") - .smart_ptr>("MICMPtr") - .class_function("fromConfigPath", &MICMWrapperWASM::fromConfigPath) - .class_function("fromConfigString", &MICMWrapperWASM::fromConfigString) - .function("createState", &MICMWrapperWASM::createState) - .function("solve", &MICMWrapperWASM::solve) - .function("solverType", &MICMWrapperWASM::solverType); + function("getVersion", optional_override([]() { return std::string(musica::GetMusicaVersion()); })); + + function("getMicmVersion", optional_override([]() { return std::string(micm::GetMicmVersion()); })); + + class_("Condition") + .constructor<>() + .constructor() + .property("temperature", &micm::Conditions::temperature_) + .property("pressure", &micm::Conditions::pressure_) + .property("air_density", &micm::Conditions::air_density_); + + register_vector("VectorDouble"); + register_vector("VectorConditions"); + register_map("MapStringSizeT"); + register_map>("MapStringVectorDouble"); + + enum_("SolverType") + .value("Rosenbrock", musica::MICMSolver::Rosenbrock) + .value("BackwardEuler", musica::MICMSolver::BackwardEuler) + .value("CudaRosenbrock", musica::MICMSolver::CudaRosenbrock) + .value("RosenbrockStandardOrder", musica::MICMSolver::RosenbrockStandardOrder) + .value("BackwardEulerStandardOrder", musica::MICMSolver::BackwardEulerStandardOrder); + + enum_("SolverState") + .value("NotYetCalled", micm::SolverState::NotYetCalled) + .value("Running", micm::SolverState::Running) + .value("Converged", micm::SolverState::Converged) + .value("ConvergenceExceededMaxSteps", micm::SolverState::ConvergenceExceededMaxSteps) + .value("StepSizeTooSmall", micm::SolverState::StepSizeTooSmall) + .value("RepeatedlySingularMatrix", micm::SolverState::RepeatedlySingularMatrix) + .value("NaNDetected", micm::SolverState::NaNDetected) + .value("InfDetected", micm::SolverState::InfDetected) + .value("AcceptingUnconvergedIntegration", micm::SolverState::AcceptingUnconvergedIntegration); + + value_object("SolverResultsStats") + .field("function_calls", &musica::SolverResultStats::function_calls_) + .field("jacobian_updates", &musica::SolverResultStats::jacobian_updates_) + .field("number_of_steps", &musica::SolverResultStats::number_of_steps_) + .field("accepted", &musica::SolverResultStats::accepted_) + .field("rejected", &musica::SolverResultStats::rejected_) + .field("decompositions", &musica::SolverResultStats::decompositions_) + .field("solves", &musica::SolverResultStats::solves_) + .field("final_time", &musica::SolverResultStats::final_time_); + + class_("SolverResult") + .constructor() + .property( + "state", + optional_override( + [](const micm::SolverResult& r) + { + return static_cast(r.state_); + })) + .property("stats", &micm::SolverResult::stats_); + + class_("State") + .smart_ptr>("State") + .function("number_of_grid_cells", &musica::State::NumberOfGridCells) + .function("set_conditions", &musica::State::SetConditions) + .function( + "get_conditions", + optional_override( + [](std::shared_ptr state) + { + const std::vector& cppVec = state->GetConditions(); + emscripten::val result = emscripten::val::array(); + + for (size_t i = 0; i < cppVec.size(); ++i) + { + emscripten::val cond = emscripten::val::object(); + cond.set("temperature", cppVec[i].temperature_); + cond.set("pressure", cppVec[i].pressure_); + cond.set("air_density", cppVec[i].air_density_); + result.call("push", cond); + } + + return result; + })) + .function( + "get_concentrations", + optional_override( + [](std::shared_ptr state, musica::MICMSolver solver) + { + std::map> cppMap = state->GetConcentrations(solver); + emscripten::val result = emscripten::val::object(); + for (auto& [key, vec] : cppMap) + { + emscripten::val jsArray = emscripten::val::array(); + for (double v : vec) + { + jsArray.call("push", v); + } + result.set(key, jsArray); + } + return result; + })) + .function( + "set_concentrations", + optional_override( + [](std::shared_ptr state, emscripten::val input, musica::MICMSolver solver) + { + std::map> cppMap; + + emscripten::val keys = emscripten::val::global("Object").call("keys", input); + int len = keys["length"].as(); + for (int i = 0; i < len; ++i) + { + std::string key = keys[i].as(); + emscripten::val jsArray = input[key]; + cppMap[key] = emscripten::vecFromJSArray(jsArray); + } + + state->SetConcentrations(cppMap, solver); + })) + .function( + "set_user_defined_constants", + optional_override( + [](std::shared_ptr state, emscripten::val input, musica::MICMSolver solver) + { + std::map> cppMap; + + emscripten::val keys = emscripten::val::global("Object").call("keys", input); + int len = keys["length"].as(); + for (int i = 0; i < len; ++i) + { + std::string key = keys[i].as(); + emscripten::val jsArray = input[key]; + cppMap[key] = emscripten::vecFromJSArray(jsArray); + } + + state->SetRateConstants(cppMap, solver); + })) + .function( + "get_user_defined_constants", + optional_override( + [](std::shared_ptr state, musica::MICMSolver solver) + { + std::map> cppMap = state->GetRateConstants(solver); + emscripten::val result = emscripten::val::object(); + for (auto& [key, vec] : cppMap) + { + emscripten::val jsArray = emscripten::val::array(); + for (double v : vec) + { + jsArray.call("push", v); + } + result.set(key, jsArray); + } + return result; + })); + + class_("MICM") + .smart_ptr>("MICM") + .class_function( + "fromConfigPath", + optional_override([](std::string path, musica::MICMSolver solver) + { return std::make_unique(path, solver); })) + .class_function( + "fromConfigString", + optional_override( + [](std::string config_string, musica::MICMSolver solver) + { return std::make_unique(musica::ReadConfigurationFromString(config_string), solver); })) + .function( + "solve", + optional_override( + [](musica::MICM& micm, std::shared_ptr state, double dt) + { + return micm.Solve(state.get(), dt); // pass raw pointer internally + })) + .function("get_maximum_number_of_grid_cells", &musica::MICM::GetMaximumNumberOfGridCells); + + function("vector_size", musica::GetVectorSize); + + function( + "create_state", + optional_override([](musica::MICM& micm, std::size_t number_of_grid_cells) + { return std::make_shared(micm, number_of_grid_cells); })); + + function( + "species_ordering", + optional_override( + [](std::shared_ptr state) + { + std::map map; + std::visit([&map](auto& s) { map = s.variable_map_; }, state->state_variant_); + return map; + })); + + function( + "user_defined_rate_parameters_ordering", + optional_override( + [](std::shared_ptr state) + { + std::map map; + std::visit([&map](auto& s) { map = s.custom_rate_parameter_map_; }, state->state_variant_); + return map; + })); } diff --git a/javascript/tests/integration/analytical.test.js b/javascript/tests/integration/analytical.test.js index ba099290e..c4c5f0d6b 100644 --- a/javascript/tests/integration/analytical.test.js +++ b/javascript/tests/integration/analytical.test.js @@ -3,346 +3,88 @@ import assert from 'node:assert'; import path from 'path'; import * as musica from '../../index.js'; import { isClose } from '../util/testUtils.js'; - import { fileURLToPath } from 'url'; -// Convert import.meta.url to a file path const __filename = fileURLToPath(import.meta.url); const __dirname = path.dirname(__filename); const { MICM, SolverType, GAS_CONSTANT } = musica; - -// Test configuration const CONFIG_PATH = path.join(__dirname, '../../../configs/v0/analytical'); -// NOTE: Vector-ordered Rosenbrock currently only supports up to 4 grid cells -// This is because the C++ implementation requires splitting into multiple -// internal states for >4 cells (the vector size), which is not yet implemented -// in the JavaScript bindings. Python handles this by creating multiple states. -const maxCells = 4; // Vector size limitation - before(async () => { await musica.initModule(); }); -/** - * Test single grid cell analytical solution - * Equivalent to TestSingleGridCell in Python - */ -function testSingleGridCell(solver, state, timeStep, places = 5) { +// Combined single/multiple grid cell test +function testAnalytical(solver, state, numCells = 1, step = 1, places = 5) { const temperature = 272.5; const pressure = 101253.3; const airDensity = pressure / (GAS_CONSTANT * temperature); - const rateConstants = { - 'USER.reaction 1': 0.001, - 'USER.reaction 2': 0.002 - }; - - const concentrations = { - A: 0.75, - B: 0, - C: 0.4, - D: 0.8, - E: 0, - F: 0.1 - }; - - state.setConditions({ temperatures: temperature, pressures: pressure, air_densities: airDensity }); - state.setConcentrations(concentrations); - state.setUserDefinedRateParameters(rateConstants); - - // Test to make sure a second call with empty dictionary does not change values - state.setConcentrations({}); - state.setUserDefinedRateParameters({}); - - const initialConcentrations = state.getConcentrations(); - const initialRateParameters = state.getUserDefinedRateParameters(); - const initialConditions = state.getConditions(); - - // Verify initial conditions - assert.ok(isClose(initialConcentrations.A[0], concentrations.A, 1e-13), 'Initial A concentration mismatch'); - assert.ok(isClose(initialConcentrations.B[0], concentrations.B, 1e-13), 'Initial B concentration mismatch'); - assert.ok(isClose(initialConcentrations.C[0], concentrations.C, 1e-13), 'Initial C concentration mismatch'); - assert.ok(isClose(initialConcentrations.D[0], concentrations.D, 1e-13), 'Initial D concentration mismatch'); - assert.ok(isClose(initialConcentrations.E[0], concentrations.E, 1e-13), 'Initial E concentration mismatch'); - assert.ok(isClose(initialConcentrations.F[0], concentrations.F, 1e-13), 'Initial F concentration mismatch'); - assert.ok(isClose(initialRateParameters['USER.reaction 1'][0], rateConstants['USER.reaction 1'], 1e-13), 'Rate parameter 1 mismatch'); - assert.ok(isClose(initialRateParameters['USER.reaction 2'][0], rateConstants['USER.reaction 2'], 1e-13), 'Rate parameter 2 mismatch'); - assert.ok(isClose(initialConditions.temperature[0], temperature, 1e-13), 'Temperature mismatch'); - assert.ok(isClose(initialConditions.pressure[0], pressure, 1e-13), 'Pressure mismatch'); - assert.ok(isClose(initialConditions.air_density[0], airDensity, 1e-13), 'Air density mismatch'); - - timeStep = 1; - const simLength = 100; - - let currTime = timeStep; - const initialA = initialConcentrations.A[0]; - const initialC = initialConcentrations.C[0]; - const initialD = initialConcentrations.D[0]; - const initialF = initialConcentrations.F[0]; - - const tolerance = Math.pow(10, -places); - - // Integrate and compare with analytical solution - while (currTime <= simLength) { - solver.solve(state, timeStep); - const conc = state.getConcentrations(); - - const k1 = rateConstants['USER.reaction 1']; - const k2 = rateConstants['USER.reaction 2']; - const k3 = 0.004 * Math.exp(50.0 / temperature); - const k4 = 0.012 * Math.exp(75.0 / temperature) * - Math.pow(temperature / 50.0, -2) * (1.0 + 1.0e-6 * pressure); - - // Analytical solutions - const A_conc = initialA * Math.exp(-k3 * currTime); - const B_conc = initialA * (k3 / (k4 - k3)) * - (Math.exp(-k3 * currTime) - Math.exp(-k4 * currTime)); - const C_conc = initialC + initialA * - (1.0 + (k3 * Math.exp(-k4 * currTime) - k4 * Math.exp(-k3 * currTime)) / (k4 - k3)); - const D_conc = initialD * Math.exp(-k1 * currTime); - const E_conc = initialD * (k1 / (k2 - k1)) * - (Math.exp(-k1 * currTime) - Math.exp(-k2 * currTime)); - const F_conc = initialF + initialD * - (1.0 + (k1 * Math.exp(-k2 * currTime) - k2 * Math.exp(-k1 * currTime)) / (k2 - k1)); - - // Check concentrations - assert.ok(isClose(conc.A[0], A_conc, tolerance), - `A concentration mismatch at t=${currTime}: ${conc.A[0]} vs ${A_conc}`); - assert.ok(isClose(conc.B[0], B_conc, tolerance), - `B concentration mismatch at t=${currTime}: ${conc.B[0]} vs ${B_conc}`); - assert.ok(isClose(conc.C[0], C_conc, tolerance), - `C concentration mismatch at t=${currTime}: ${conc.C[0]} vs ${C_conc}`); - assert.ok(isClose(conc.D[0], D_conc, tolerance), - `D concentration mismatch at t=${currTime}: ${conc.D[0]} vs ${D_conc}`); - assert.ok(isClose(conc.E[0], E_conc, tolerance), - `E concentration mismatch at t=${currTime}: ${conc.E[0]} vs ${E_conc}`); - assert.ok(isClose(conc.F[0], F_conc, tolerance), - `F concentration mismatch at t=${currTime}: ${conc.F[0]} vs ${F_conc}`); - - currTime += timeStep; - } -} - -/** - * Test multiple grid cells analytical solution - * Equivalent to TestMultipleGridCell in Python - */ -function testMultipleGridCell(solver, state, numGridCells, timeStep, places = 5) { - const concentrations = { - A: [], - B: [], - C: [], - D: [], - E: [], - F: [] - }; - const rateConstants = { - 'USER.reaction 1': [], - 'USER.reaction 2': [] - }; - const temperatures = []; - const pressures = []; - - // Generate random initial conditions for each grid cell - for (let i = 0; i < numGridCells; i++) { - temperatures.push(275.0 + (Math.random() - 0.5) * 100.0); - pressures.push(101253.3 + (Math.random() - 0.5) * 1000.0); - concentrations.A.push(0.75 + (Math.random() - 0.5) * 0.1); - concentrations.B.push(0); - concentrations.C.push(0.4 + (Math.random() - 0.5) * 0.1); - concentrations.D.push(0.8 + (Math.random() - 0.5) * 0.1); - concentrations.E.push(0); - concentrations.F.push(0.1 + (Math.random() - 0.5) * 0.1); - rateConstants['USER.reaction 1'].push(0.001 + (Math.random() - 0.5) * 0.0002); - rateConstants['USER.reaction 2'].push(0.002 + (Math.random() - 0.5) * 0.0002); + const conc = { A: [], B: [], C: [], D: [], E: [], F: [] }; + const rates = { 'USER.reaction 1': [], 'USER.reaction 2': [] }; + const temps = []; + const press = []; + + for (let i = 0; i < numCells; i++) { + temps.push(temperature); + press.push(pressure); + conc.A.push(0.75); conc.B.push(0); conc.C.push(0.4); + conc.D.push(0.8); conc.E.push(0); conc.F.push(0.1); + rates['USER.reaction 1'].push(0.001); + rates['USER.reaction 2'].push(0.002); } - state.setConditions({ temperatures, pressures }); // Air density calculated automatically - state.setConcentrations(concentrations); - state.setUserDefinedRateParameters(rateConstants); - - const initialConcentrations = state.getConcentrations(); - const initialRateParameters = state.getUserDefinedRateParameters(); - const conditions = state.getConditions(); - - // Verify initial conditions - for (let i = 0; i < numGridCells; i++) { - assert.ok(isClose(initialConcentrations.A[i], concentrations.A[i], 1e-13)); - assert.ok(isClose(initialConcentrations.B[i], concentrations.B[i], 1e-13)); - assert.ok(isClose(initialConcentrations.C[i], concentrations.C[i], 1e-13)); - assert.ok(isClose(initialConcentrations.D[i], concentrations.D[i], 1e-13)); - assert.ok(isClose(initialConcentrations.E[i], concentrations.E[i], 1e-13)); - assert.ok(isClose(initialConcentrations.F[i], concentrations.F[i], 1e-13)); - assert.ok(isClose(initialRateParameters['USER.reaction 1'][i], rateConstants['USER.reaction 1'][i], 1e-13)); - assert.ok(isClose(initialRateParameters['USER.reaction 2'][i], rateConstants['USER.reaction 2'][i], 1e-13)); - assert.ok(isClose(conditions.temperature[i], temperatures[i], 1e-13)); - assert.ok(isClose(conditions.pressure[i], pressures[i], 1e-13)); - const expectedAirDensity = pressures[i] / (8.31446261815324 * temperatures[i]); - assert.ok(isClose(conditions.air_density[i], expectedAirDensity, 1e-13)); - } - - timeStep = 1; - const simLength = 100; - - let currTime = timeStep; - const initialA = initialConcentrations.A.slice(); - const initialC = initialConcentrations.C.slice(); - const initialD = initialConcentrations.D.slice(); - const initialF = initialConcentrations.F.slice(); - - const k1 = []; - const k2 = []; - const k3 = []; - const k4 = []; - for (let i = 0; i < numGridCells; i++) { - k1.push(rateConstants['USER.reaction 1'][i]); - k2.push(rateConstants['USER.reaction 2'][i]); - k3.push(0.004 * Math.exp(50.0 / temperatures[i])); - k4.push(0.012 * Math.exp(75.0 / temperatures[i]) * - Math.pow(temperatures[i] / 50.0, -2) * (1.0 + 1.0e-6 * pressures[i])); - } + state.setConditions({ temperatures: temps, pressures: press, airDensities: Array(numCells).fill(airDensity) }); + state.setConcentrations(conc); + state.setUserDefinedRateParameters(rates); const tolerance = Math.pow(10, -places); - // Integrate and compare with analytical solution - while (currTime <= simLength) { - solver.solve(state, timeStep); - const conc = state.getConcentrations(); - - for (let i = 0; i < numGridCells; i++) { - // Analytical solutions - const A_conc = initialA[i] * Math.exp(-k3[i] * currTime); - const B_conc = initialA[i] * (k3[i] / (k4[i] - k3[i])) * - (Math.exp(-k3[i] * currTime) - Math.exp(-k4[i] * currTime)); - const C_conc = initialC[i] + initialA[i] * (1.0 + ( - k3[i] * Math.exp(-k4[i] * currTime) - k4[i] * Math.exp(-k3[i] * currTime)) / (k4[i] - k3[i])); - const D_conc = initialD[i] * Math.exp(-k1[i] * currTime); - const E_conc = initialD[i] * (k1[i] / (k2[i] - k1[i])) * - (Math.exp(-k1[i] * currTime) - Math.exp(-k2[i] * currTime)); - const F_conc = initialF[i] + initialD[i] * (1.0 + ( - k1[i] * Math.exp(-k2[i] * currTime) - k2[i] * Math.exp(-k1[i] * currTime)) / (k2[i] - k1[i])); - - // Check concentrations - assert.ok(isClose(conc.A[i], A_conc, tolerance), - `Grid cell ${i} of ${numGridCells}: A concentration mismatch. Initial A: ${initialConcentrations.A[i]}`); - assert.ok(isClose(conc.B[i], B_conc, tolerance), - `Grid cell ${i} of ${numGridCells}: B concentration mismatch. Initial B: ${initialConcentrations.B[i]}`); - assert.ok(isClose(conc.C[i], C_conc, tolerance), - `Grid cell ${i} of ${numGridCells}: C concentration mismatch. Initial C: ${initialConcentrations.C[i]}`); - assert.ok(isClose(conc.D[i], D_conc, tolerance), - `Grid cell ${i} of ${numGridCells}: D concentration mismatch. Initial D: ${initialConcentrations.D[i]}`); - assert.ok(isClose(conc.E[i], E_conc, tolerance), - `Grid cell ${i} of ${numGridCells}: E concentration mismatch. Initial E: ${initialConcentrations.E[i]}`); - assert.ok(isClose(conc.F[i], F_conc, tolerance), - `Grid cell ${i} of ${numGridCells}: F concentration mismatch. Initial F: ${initialConcentrations.F[i]}`); + for (let t = step; t <= 100; t += step) { + solver.solve(state, step); + const c = state.getConcentrations(); + + for (let i = 0; i < numCells; i++) { + const k1 = rates['USER.reaction 1'][i]; + const k2 = rates['USER.reaction 2'][i]; + const k3 = 0.004 * Math.exp(50.0 / temps[i]); + const k4 = 0.012 * Math.exp(75.0 / temps[i]) * + Math.pow(temps[i] / 50.0, -2) * (1.0 + 1.0e-6 * press[i]); + + const A = conc.A[i] * Math.exp(-k3 * t); + const B = conc.A[i] * (k3 / (k4 - k3)) * (Math.exp(-k3 * t) - Math.exp(-k4 * t)); + const C = conc.C[i] + conc.A[i] * (1.0 + (k3 * Math.exp(-k4 * t) - k4 * Math.exp(-k3 * t)) / (k4 - k3)); + const D = conc.D[i] * Math.exp(-k1 * t); + const E = conc.D[i] * (k1 / (k2 - k1)) * (Math.exp(-k1 * t) - Math.exp(-k2 * t)); + const F = conc.F[i] + conc.D[i] * (1.0 + (k1 * Math.exp(-k2 * t) - k2 * Math.exp(-k1 * t)) / (k2 - k1)); + + assert.ok(isClose(c.A[i], A, tolerance), `Grid cell ${i}: A mismatch`); + assert.ok(isClose(c.B[i], B, tolerance), `Grid cell ${i}: B mismatch`); + assert.ok(isClose(c.C[i], C, tolerance), `Grid cell ${i}: C mismatch`); + assert.ok(isClose(c.D[i], D, tolerance), `Grid cell ${i}: D mismatch`); + assert.ok(isClose(c.E[i], E, tolerance), `Grid cell ${i}: E mismatch`); + assert.ok(isClose(c.F[i], F, tolerance), `Grid cell ${i}: F mismatch`); } - - currTime += timeStep; } } -// Test suite for single grid cell - Standard Rosenbrock -describe('Analytical - Single grid cell - Standard Rosenbrock', () => { - it('should match analytical solution', async (t) => { - const solver = MICM.fromConfigPath( - CONFIG_PATH, - SolverType.rosenbrock_standard_order - ); - const state = solver.createState(1); - testSingleGridCell(solver, state, 200.0, 5); - }); -}); - -// Test suite for multiple grid cells - Standard Rosenbrock -describe('Analytical - Multiple grid cells - Standard Rosenbrock', () => { - for (let i = 1; i <= maxCells; i++) { - it(`should match analytical solution for ${i} grid cells`, async (t) => { - const solver = MICM.fromConfigPath( - CONFIG_PATH, - SolverType.rosenbrock_standard_order - ); - const state = solver.createState(i); - testMultipleGridCell(solver, state, i, 200.0, 5); - }); - } -}); - -// Test suite for single grid cell - Rosenbrock (vector-ordered) -describe('Analytical - Single grid cell - Rosenbrock', () => { - it('should match analytical solution', async (t) => { - const solver = MICM.fromConfigPath( - CONFIG_PATH, - SolverType.rosenbrock - ); - const state = solver.createState(1); - testSingleGridCell(solver, state, 200.0, 5); - }); -}); - -// Test suite for multiple grid cells - Rosenbrock (vector-ordered) -describe('Analytical - Multiple grid cells - Rosenbrock', () => { - for (let i = 1; i <= maxCells; i++) { - it(`should match analytical solution for ${i} grid cells`, async (t) => { - const solver = MICM.fromConfigPath( - CONFIG_PATH, - SolverType.rosenbrock - ); - const state = solver.createState(i); - testMultipleGridCell(solver, state, i, 200.0, 5); - }); - } -}); - -// Test suite for single grid cell - Backward Euler -describe('Analytical - Single grid cell - Backward Euler', () => { - it('should match analytical solution', () => { - const solver = MICM.fromConfigPath( - CONFIG_PATH, - SolverType.backward_euler - ); - const state = solver.createState(1); - testSingleGridCell(solver, state, 10.0, 2); - }); -}); - -// Test suite for multiple grid cells - Backward Euler -describe('Analytical - Multiple grid cells - Backward Euler', () => { - for (let i = 1; i <= maxCells; i++) { - it(`should match analytical solution for ${i} grid cells`, () => { - const solver = MICM.fromConfigPath( - CONFIG_PATH, - SolverType.backward_euler - ); - const state = solver.createState(i); - testMultipleGridCell(solver, state, i, 10.0, 2); - }); - } -}); - -// Test suite for single grid cell - Backward Euler Standard Order -describe('Analytical - Single grid cell - Backward Euler Standard Order', () => { - it('should match analytical solution', () => { - const solver = MICM.fromConfigPath( - CONFIG_PATH, - SolverType.backward_euler_standard_order - ); - const state = solver.createState(1); - testSingleGridCell(solver, state, 10.0, 2); +const solvers = [ + { name: 'Rosenbrock', type: SolverType.rosenbrock, step: 200.0, places: 5 }, + { name: 'Rosenbrock Standard', type: SolverType.rosenbrock_standard_order, step: 200.0, places: 5 }, + { name: 'Backward Euler', type: SolverType.backward_euler, step: 10.0, places: 2 }, + { name: 'Backward Euler Standard', type: SolverType.backward_euler_standard_order, step: 10.0, places: 2 }, +]; +const gridCellsList = [1, 5, 10, 20]; + +solvers.forEach(({ name, type, step, places }) => { + describe(`Analytical - ${name}`, () => { + for (const nCells of gridCellsList) { + it(`should match analytical solution for ${nCells} grid cell${nCells > 1 ? 's' : ''}`, () => { + const solver = MICM.fromConfigPath(CONFIG_PATH, type); + const state = solver.createState(nCells); + testAnalytical(solver, state, nCells, step, places); + }); + } }); }); - -// Test suite for multiple grid cells - Backward Euler Standard Order -describe('Analytical - Multiple grid cells - Backward Euler Standard Order', () => { - for (let i = 1; i <= maxCells; i++) { - it(`should match analytical solution for ${i} grid cells`, () => { - const solver = MICM.fromConfigPath( - CONFIG_PATH, - SolverType.backward_euler_standard_order - ); - const state = solver.createState(i); - testMultipleGridCell(solver, state, i, 10.0, 2); - }); - } -}); diff --git a/javascript/tests/unit/micm.test.js b/javascript/tests/unit/micm.test.js index a9a6a1ff6..cbdb1d52d 100644 --- a/javascript/tests/unit/micm.test.js +++ b/javascript/tests/unit/micm.test.js @@ -29,24 +29,53 @@ function getConfigPath() { return path.join(__dirname, '../../../configs/v0/analytical'); } -describe('MICM Initialization', () => { - it('should initialize with fromConfigPath', async (t) => { - const micm = MICM.fromConfigPath(getConfigPath()); - assert.ok(micm, 'MICM should be created'); - assert.ok(micm.solverType() !== null, 'Solver type should be set'); +function createTestMechanism() { + // Create a simple mechanism for testing + const A = new Species({ name: 'A' }); + const B = new Species({ name: 'B' }); + const C = new Species({ name: 'C' }); + + const gas = new Phase({ + name: 'gas', + species: [A, B, C] + }); + + const reaction1 = new reactionTypes.UserDefined({ + name: 'reaction 1', + gas_phase: 'gas', + reactants: [new ReactionComponent({ species_name: 'A' })], + products: [new ReactionComponent({ species_name: 'B' })] }); + const reaction2 = new reactionTypes.UserDefined({ + name: 'reaction 2', + gas_phase: 'gas', + reactants: [new ReactionComponent({ species_name: 'B' })], + products: [new ReactionComponent({ species_name: 'C' })] + }); + + const mechanism = new Mechanism({ + name: 'Test Mechanism', + version: '1.0.0', + species: [A, B, C], + phases: [gas], + reactions: [reaction1, reaction2] + }); + return mechanism; +} + +describe('MICM Initialization', () => { it('should initialize with fromConfigPath and solver_type', async (t) => { - const micm = MICM.fromConfigPath( - getConfigPath(), - SolverType.rosenbrock_standard_order - ); - assert.ok(micm, 'MICM should be created'); - assert.strictEqual( - micm.solverType(), - SolverType.rosenbrock_standard_order, - 'Solver type should match' - ); + const types = Object.values(SolverType); + for (const solverType of types) { + const micm = MICM.fromConfigPath(getConfigPath(), solverType); + assert.ok(micm, 'MICM should be created'); + assert.strictEqual( + micm.solverType(), + solverType, + `Solver type should match ${solverType}` + ); + } }); it('should throw error with invalid config_path type', async (t) => { @@ -57,79 +86,34 @@ describe('MICM Initialization', () => { ); }); - it('should use default solver type', async (t) => { - const micm = MICM.fromConfigPath(getConfigPath()); - assert.strictEqual( - micm.solverType(), - SolverType.rosenbrock_standard_order, - 'Default solver type should be rosenbrock_standard_order' - ); - }); - - it('should initialize with backward_euler_standard_order', async (t) => { - const micm = MICM.fromConfigPath( - getConfigPath(), - SolverType.backward_euler_standard_order - ); - assert.strictEqual( - micm.solverType(), - SolverType.backward_euler_standard_order, - 'Solver type should be backward_euler_standard_order' - ); - }); - it('should initialize and solve with fromMechanism', async (t) => { - // Create a simple mechanism for testing - const A = new Species({ name: 'A' }); - const B = new Species({ name: 'B' }); - const C = new Species({ name: 'C' }); - - const gas = new Phase({ - name: 'gas', - species: [A, B, C] - }); - - const reaction1 = new reactionTypes.UserDefined({ - name: 'reaction 1', - gas_phase: 'gas', - reactants: [new ReactionComponent({ species_name: 'A' })], - products: [new ReactionComponent({ species_name: 'B' })] - }); - - const reaction2 = new reactionTypes.UserDefined({ - name: 'reaction 2', - gas_phase: 'gas', - reactants: [new ReactionComponent({ species_name: 'B' })], - products: [new ReactionComponent({ species_name: 'C' })] - }); - - const mechanism = new Mechanism({ - name: 'Test Mechanism', - version: '1.0.0', - species: [A, B, C], - phases: [gas], - reactions: [reaction1, reaction2] - }); - - const micm = MICM.fromMechanism(mechanism); - assert.ok(micm, 'MICM should be created from mechanism'); - assert.ok(micm.solverType() !== null, 'Solver type should be set'); + const types = Object.values(SolverType); + const mechanism = createTestMechanism(); + for (const solverType of types) { + const micm = MICM.fromMechanism(mechanism, solverType); + assert.ok(micm, 'MICM should be created from mechanism'); + assert.strictEqual( + micm.solverType(), + solverType, + `Solver type should match ${solverType}` + ); - const state = micm.createState(1); - state.setConcentrations({ A: [1.0], B: [2.0], C: [3.0] }); - state.setConditions({ temperatures: [298.15], pressures: [101325.0], air_densities: [1.0] }); - state.setUserDefinedRateParameters({ - 'USER.reaction 1': 0.001, - 'USER.reaction 2': 0.002 - }); + const state = micm.createState(1); + state.setConcentrations({ A: [1.0], B: [2.0], C: [3.0] }); + state.setConditions({ temperatures: [298.15], pressures: [101325.0], air_densities: [1.0] }); + state.setUserDefinedRateParameters({ + 'USER.reaction 1': 0.001, + 'USER.reaction 2': 0.002 + }); - const result = micm.solve(state, 60.0); - assert.ok(result, 'MICM.solve should return a result'); - let concentrations = state.getConcentrations(); + const result = micm.solve(state, 60.0); + assert.ok(result, 'MICM.solve should return a result'); + let concentrations = state.getConcentrations(); - assert.ok(concentrations.A !== 1.0, 'Concentration of A should have changed after solve'); - assert.ok(concentrations.B !== 2.0, 'Concentration of B should have changed after solve'); - assert.ok(concentrations.C !== 3.0, 'Concentration of C should have changed after solve'); + assert.ok(concentrations.A !== 1.0, 'Concentration of A should have changed after solve'); + assert.ok(concentrations.B !== 2.0, 'Concentration of B should have changed after solve'); + assert.ok(concentrations.C !== 3.0, 'Concentration of C should have changed after solve'); + } }); it('should throw error with invalid mechanism', async (t) => { @@ -149,36 +133,6 @@ describe('MICM Initialization', () => { }); }); -describe('MICM solverType method', () => { - it('should return correct solver type', async (t) => { - const micm = MICM.fromConfigPath( - getConfigPath(), - SolverType.rosenbrock_standard_order - ); - assert.strictEqual( - micm.solverType(), - SolverType.rosenbrock_standard_order, - 'Should return rosenbrock_standard_order' - ); - }); - - it('should work with different solver types', async (t) => { - const solverTypes = [ - SolverType.rosenbrock_standard_order, - SolverType.backward_euler_standard_order, - ]; - - for (const solverType of solverTypes) { - const micm = MICM.fromConfigPath(getConfigPath(), solverType); - assert.strictEqual( - micm.solverType(), - solverType, - `Solver type should match ${solverType}` - ); - } - }); -}); - describe('MICM createState method', () => { let micm; @@ -233,254 +187,139 @@ describe('MICM createState method', () => { }); }); -describe('MICM solve method', () => { - let micm; - - it('should solve with valid state and timestep', async (t) => { - micm = MICM.fromConfigPath(getConfigPath()); - const state = micm.createState(1); - - // Set initial conditions - state.setConditions({ - temperatures: 298.15, - pressures: 101325.0, - air_densities: 1.2 - }); - state.setConcentrations({ A: 1.0, B: 0.0, C: 0.5 }); - state.setUserDefinedRateParameters({ - 'USER.reaction 1': 0.001, - 'USER.reaction 2': 0.002 - }); - - // Solve - const result = micm.solve(state, 1.0); - - assert.ok(result instanceof SolverResult, 'Should return SolverResult'); - assert.ok(result.stats instanceof SolverStats, 'Should have SolverStats'); - }); - - it('should solve with float timestep', async (t) => { - micm = MICM.fromConfigPath(getConfigPath()); - const state = micm.createState(1); - - state.setConditions({ - temperatures: 298.15, - pressures: 101325.0, - air_densities: 1.2 - }); - state.setConcentrations({ A: 1.0 }); - state.setUserDefinedRateParameters({ - 'USER.reaction 1': 0.001, - 'USER.reaction 2': 0.002 - }); - - const result = micm.solve(state, 1.5); - assert.ok(result instanceof SolverResult, 'Should solve with float timestep'); - }); - - it('should solve with integer timestep', async (t) => { - micm = MICM.fromConfigPath(getConfigPath()); - const state = micm.createState(1); - - state.setConditions({ - temperatures: 298.15, - pressures: 101325.0, - air_densities: 1.2 - }); - state.setConcentrations({ A: 1.0 }); - state.setUserDefinedRateParameters({ - 'USER.reaction 1': 0.001, - 'USER.reaction 2': 0.002 - }); - - const result = micm.solve(state, 1); - assert.ok(result instanceof SolverResult, 'Should solve with integer timestep'); - }); +describe('MICM Solve - comprehensive', () => { + const gridCellsList = [1, 5, 10, 20]; + const solverTypes = Object.values(SolverType); - it('should throw error for invalid state type', async (t) => { - micm = MICM.fromConfigPath(getConfigPath()); - - assert.throws( - () => micm.solve('not a state', 1.0), - /state must be an instance of State/, - 'Should throw TypeError for invalid state' - ); + let mechanism; + before(() => { + mechanism = createTestMechanism(); }); - it('should throw error for invalid timestep type', async (t) => { - micm = MICM.fromConfigPath(getConfigPath()); - const state = micm.createState(1); - - assert.throws( - () => micm.solve(state, 'not a number'), - /timeStep must be a number/, - 'Should throw TypeError for invalid timestep' - ); - }); - - it('should throw error for null state', async (t) => { - micm = MICM.fromConfigPath(getConfigPath()); - - assert.throws( - () => micm.solve(null, 1.0), - /state must be an instance of State/, - 'Should throw TypeError for null state' - ); - }); - - it('should throw error for undefined state', async (t) => { - micm = MICM.fromConfigPath(getConfigPath()); - - assert.throws( - () => micm.solve(undefined, 1.0), - /state must be an instance of State/, - 'Should throw TypeError for undefined state' - ); - }); - - it('should solve multiple times with same state', async (t) => { - micm = MICM.fromConfigPath(getConfigPath()); - const state = micm.createState(1); - - state.setConditions({ - temperatures: 298.15, - pressures: 101325.0, - air_densities: 1.2 - }); - state.setConcentrations({ A: 1.0, B: 0.0 }); - state.setUserDefinedRateParameters({ - 'USER.reaction 1': 0.001, - 'USER.reaction 2': 0.002 - }); - - // Solve multiple times - for (let i = 0; i < 5; i++) { - const result = micm.solve(state, 1.0); - assert.ok( - result instanceof SolverResult, - `Should solve on iteration ${i + 1}` - ); + it('should solve from a mechanism in code', async (t) => { + for (const solverType of solverTypes) { + const micm = MICM.fromMechanism(mechanism, solverType); + + for (const nCells of gridCellsList) { + const state = micm.createState(nCells); + + // Set initial concentrations and conditions + const conc = { A: Array(nCells).fill(1.0), B: Array(nCells).fill(0.0), C: Array(nCells).fill(0.5) }; + state.setConcentrations(conc); + state.setConditions({ + temperatures: Array(nCells).fill(298.15), + pressures: Array(nCells).fill(101325.0), + airDensities: Array(nCells).fill(1.0) + }); + state.setUserDefinedRateParameters({ + 'USER.reaction 1': Array(nCells).fill(0.001), + 'USER.reaction 2': Array(nCells).fill(0.002) + }); + + const result = micm.solve(state, 1.0); + assert.ok(result instanceof SolverResult, 'Should return SolverResult'); + assert.ok(result.stats instanceof SolverStats, 'Should return SolverStats'); + + const updated = state.getConcentrations(); + for (let i = 0; i < nCells; i++) { + assert.ok(updated.A[i] !== 1.0, `A[${i}] should have changed`); + assert.ok(updated.B[i] !== 0.0, `B[${i}] should have changed`); + assert.ok(updated.C[i] !== 0.5, `C[${i}] should have changed`); + } + } } }); -}); - -describe('MICM Integration Tests', () => { - it('should complete end-to-end workflow', async (t) => { - // Initialize solver - const micm = MICM.fromConfigPath( - getConfigPath(), - SolverType.rosenbrock_standard_order - ); - - // Create state - const state = micm.createState(1); - - // Set conditions - state.setConditions({ - temperatures: 298.15, - pressures: 101325.0, - air_densities: 1.2 - }); - - state.setConcentrations({ - A: 0.75, - B: 0.0, - C: 0.4, - D: 0.8, - E: 0.0, - F: 0.1 - }); - - state.setUserDefinedRateParameters({ - 'USER.reaction 1': 0.001, - 'USER.reaction 2': 0.002 - }); - - // Solve - const result = micm.solve(state, 1.0); - // Verify results - assert.ok(result instanceof SolverResult, 'Should return SolverResult'); - assert.ok(result.stats instanceof SolverStats, 'Should have SolverStats'); + it('should solve from a config file', async (t) => { + for (const solverType of solverTypes) { + const micm = MICM.fromConfigPath(getConfigPath(), solverType); - // Get updated concentrations - const concentrations = state.getConcentrations(); - assert.ok('A' in concentrations, 'Should have species A'); - assert.ok('B' in concentrations, 'Should have species B'); - assert.ok('C' in concentrations, 'Should have species C'); + for (const nCells of gridCellsList) { + const state = micm.createState(nCells); + + const conc = { A: Array(nCells).fill(0.5), B: Array(nCells).fill(1.0), C: Array(nCells).fill(0.0) }; + state.setConcentrations(conc); + state.setConditions({ + temperatures: Array(nCells).fill(300), + pressures: Array(nCells).fill(101325), + airDensities: Array(nCells).fill(1.2) + }); + state.setUserDefinedRateParameters({ + 'USER.reaction 1': Array(nCells).fill(0.002), + 'USER.reaction 2': Array(nCells).fill(0.004) + }); + + const result = micm.solve(state, 2.0); + assert.ok(result instanceof SolverResult, 'Should return SolverResult'); + assert.ok(result.stats instanceof SolverStats, 'Should have stats'); + + const updated = state.getConcentrations(); + for (let i = 0; i < nCells; i++) { + assert.ok(updated.A[i] !== 0.5, `A[${i}] should have changed`); + assert.ok(updated.B[i] !== 1.0, `B[${i}] should have changed`); + assert.ok(updated.C[i] !== 0.0, `C[${i}] should have changed`); + } + } + } }); it('should handle multiple solve iterations', async (t) => { - const micm = MICM.fromConfigPath(getConfigPath()); - const state = micm.createState(1); + const micm = MICM.fromMechanism(mechanism, solverTypes[0]); + const state = micm.createState(5); + state.setConcentrations({ A: Array(5).fill(1.0), B: Array(5).fill(0.0), C: Array(5).fill(0.5) }); state.setConditions({ - temperatures: 298.15, - pressures: 101325.0, - air_densities: 1.2 - }); - - state.setConcentrations({ - A: 1.0, - B: 0.0, - C: 0.5 + temperatures: Array(5).fill(298.15), + pressures: Array(5).fill(101325.0), + airDensities: Array(5).fill(1.0) }); - state.setUserDefinedRateParameters({ - 'USER.reaction 1': 0.001, - 'USER.reaction 2': 0.002 + 'USER.reaction 1': Array(5).fill(0.001), + 'USER.reaction 2': Array(5).fill(0.002) }); - const numIterations = 10; - for (let i = 0; i < numIterations; i++) { + for (let i = 0; i < 10; i++) { const result = micm.solve(state, 0.5); - assert.ok( - result instanceof SolverResult, - `Iteration ${i + 1} should succeed` - ); - assert.ok( - typeof result.stats.final_time === 'number', - 'Should have final_time' - ); + assert.ok(result instanceof SolverResult, `Iteration ${i + 1} should succeed`); + assert.ok(result.stats instanceof SolverStats, 'Should have stats'); } - // Verify concentrations changed - const finalConcentrations = state.getConcentrations(); - assert.ok(finalConcentrations.A[0] !== 1.0, 'Concentration A should have changed'); + const final = state.getConcentrations(); + for (let i = 0; i < 5; i++) { + assert.ok(final.A[i] !== 1.0, `A[${i}] should have changed`); + } }); - it('should work with different solver types', async (t) => { - const solverTypes = [ - SolverType.rosenbrock_standard_order, - SolverType.backward_euler_standard_order, - ]; - + it('should work with all solver types', async (t) => { for (const solverType of solverTypes) { const micm = MICM.fromConfigPath(getConfigPath(), solverType); + const state = micm.createState(2); - const state = micm.createState(1); - + state.setConcentrations({ A: [1.0, 0.5], B: [0.0, 0.5], C: [0.5, 1.0] }); state.setConditions({ - temperatures: 298.15, - pressures: 101325.0, - air_densities: 1.2 + temperatures: [298.15, 300], + pressures: [101325.0, 101325.0], + airDensities: [1.0, 1.2] }); - - state.setConcentrations({ A: 1.0, B: 0.0 }); state.setUserDefinedRateParameters({ - 'USER.reaction 1': 0.001, - 'USER.reaction 2': 0.002 + 'USER.reaction 1': [0.001, 0.002], + 'USER.reaction 2': [0.002, 0.003] }); const result = micm.solve(state, 1.0); - assert.ok( - result instanceof SolverResult, - `Should solve with solver type ${solverType}` - ); + assert.ok(result instanceof SolverResult, `Should solve with solver ${solverType}`); + + const updated = state.getConcentrations(); + for (let i = 0; i < 2; i++) { + assert.ok(updated.A[i] !== (i === 0 ? 1.0 : 0.5), `A[${i}] should have changed`); + assert.ok(updated.B[i] !== (i === 0 ? 0.0 : 0.5), `B[${i}] should have changed`); + assert.ok(updated.C[i] !== (i === 0 ? 0.5 : 1.0), `C[${i}] should have changed`); + } } }); }); + describe('MICM SolverResult validation', () => { it('should return valid SolverResult structure', async (t) => { const micm = MICM.fromConfigPath(getConfigPath()); diff --git a/javascript/tests/unit/state.test.js b/javascript/tests/unit/state.test.js index 915af85f9..4b457ed12 100644 --- a/javascript/tests/unit/state.test.js +++ b/javascript/tests/unit/state.test.js @@ -25,7 +25,7 @@ before(async () => { * Helper function to create a test mechanism * This creates a simple mechanism for testing state operations */ -function createTestMechanism() { +function getConfigPath() { // For JavaScript, we'll use a config path instead of mechanism object // as the mechanism configuration API might not be fully exposed return path.join(__dirname, '../../../configs/v0/analytical'); @@ -33,21 +33,21 @@ function createTestMechanism() { describe('State initialization', () => { it('should create state with single grid cell', async (t) => { - const configPath = createTestMechanism(); + const configPath = getConfigPath(); const solver = MICM.fromConfigPath(configPath); const state = solver.createState(1); assert.ok(state, 'State should be created'); }); it('should create state with multiple grid cells', async (t) => { - const configPath = createTestMechanism(); + const configPath = getConfigPath(); const solver = MICM.fromConfigPath(configPath); const state = solver.createState(3); assert.ok(state, 'State with 3 grid cells should be created'); }); it('should throw error for invalid grid cell count', async (t) => { - const configPath = createTestMechanism(); + const configPath = getConfigPath(); const solver = MICM.fromConfigPath(configPath); assert.throws( () => solver.createState(0), @@ -60,177 +60,155 @@ describe('State initialization', () => { describe('Concentrations', () => { let solver; let state; + const gridCellsList = [1, 5, 10, 20]; before(() => { - const configPath = createTestMechanism(); + const configPath = getConfigPath(); solver = MICM.fromConfigPath(configPath); }); - it('should set and get concentrations for single grid cell', async (t) => { - state = solver.createState(1); - const concentrations = { A: 1.0, B: 2.0, C: 3.0 }; - state.setConcentrations(concentrations); - const result = state.getConcentrations(); + it('should set and get concentrations for various numbers of grid cells', async () => { + for (const nCells of gridCellsList) { + state = solver.createState(nCells); - assert.ok(isClose(result.A[0], 1.0, 1e-13), 'A concentration should be 1.0'); - assert.ok(isClose(result.B[0], 2.0, 1e-13), 'B concentration should be 2.0'); - assert.ok(isClose(result.C[0], 3.0, 1e-13), 'C concentration should be 3.0'); - }); + // Set concentrations + const concentrations = { + A: Array.from({ length: nCells }, (_, i) => i + 1), + B: Array.from({ length: nCells }, (_, i) => i + 10), + C: Array.from({ length: nCells }, (_, i) => i + 100) + }; + state.setConcentrations(concentrations); - it('should set and get concentrations for multiple grid cells', async (t) => { - state = solver.createState(2); - const concentrations = { A: [1.0, 2.0], B: [3.0, 4.0], C: [5.0, 6.0] }; - state.setConcentrations(concentrations); - const result = state.getConcentrations(); + const result = state.getConcentrations(); - assert.deepStrictEqual(result.A, [1.0, 2.0], 'A concentrations should match'); - assert.deepStrictEqual(result.B, [3.0, 4.0], 'B concentrations should match'); - assert.deepStrictEqual(result.C, [5.0, 6.0], 'C concentrations should match'); + for (let i = 0; i < nCells; i++) { + assert.ok(isClose(result.A[i], concentrations.A[i], 1e-13), `A[${i}] should match`); + assert.ok(isClose(result.B[i], concentrations.B[i], 1e-13), `B[${i}] should match`); + assert.ok(isClose(result.C[i], concentrations.C[i], 1e-13), `C[${i}] should match`); + } + } }); - it('should handle empty concentration update', async (t) => { - state = solver.createState(1); - const concentrations = { A: 1.0, B: 2.0, C: 3.0 }; - state.setConcentrations(concentrations); + it('should handle empty concentration update without changing values', async () => { + for (const nCells of gridCellsList) { + state = solver.createState(nCells); + + const concentrations = { + A: Array(nCells).fill(1.0), + B: Array(nCells).fill(2.0), + C: Array(nCells).fill(3.0) + }; + state.setConcentrations(concentrations); - // Set empty concentrations - should not change values - state.setConcentrations({}); - const result = state.getConcentrations(); + // Empty update + state.setConcentrations({}); + const result = state.getConcentrations(); - assert.ok(isClose(result.A[0], 1.0, 1e-13), 'A concentration should remain 1.0'); - assert.ok(isClose(result.B[0], 2.0, 1e-13), 'B concentration should remain 2.0'); - assert.ok(isClose(result.C[0], 3.0, 1e-13), 'C concentration should remain 3.0'); + for (let i = 0; i < nCells; i++) { + assert.ok(isClose(result.A[i], concentrations.A[i], 1e-13), `A[${i}] should remain`); + assert.ok(isClose(result.B[i], concentrations.B[i], 1e-13), `B[${i}] should remain`); + assert.ok(isClose(result.C[i], concentrations.C[i], 1e-13), `C[${i}] should remain`); + } + } }); }); + describe('Conditions', () => { let solver; let state; + const gridCellsList = [1, 5, 10, 20]; before(() => { - const configPath = createTestMechanism(); + const configPath = getConfigPath(); solver = MICM.fromConfigPath(configPath); }); - it('should set and get conditions for single grid cell', async (t) => { - state = solver.createState(1); - state.setConditions({ temperatures: 300.0, pressures: 101325.0 }); - const conditions = state.getConditions(); - - assert.ok(isClose(conditions.temperature[0], 300.0, 1e-13), 'Temperature should be 300.0'); - assert.ok(isClose(conditions.pressure[0], 101325.0, 1e-13), 'Pressure should be 101325.0'); - // Air density should be calculated: P / (R * T) where R = 8.31446261815324 - const expectedAirDensity = 101325.0 / (8.31446261815324 * 300.0); - assert.ok(isClose(conditions.air_density[0], expectedAirDensity, 0.1), 'Air density should be calculated'); - }); - - it('should set and get conditions for multiple grid cells', async (t) => { - state = solver.createState(2); - state.setConditions({ - temperatures: [300.0, 310.0], - pressures: [101325.0, 101325.0], - air_densities: [40.9, 39.5] - }); - const conditions = state.getConditions(); - - assert.deepStrictEqual(conditions.temperature, [300.0, 310.0], 'Temperatures should match'); - assert.deepStrictEqual(conditions.pressure, [101325.0, 101325.0], 'Pressures should match'); - assert.deepStrictEqual(conditions.air_density, [40.9, 39.5], 'Air densities should match'); - }); + it('should set and get conditions for various numbers of grid cells', async () => { + for (const nCells of gridCellsList) { + state = solver.createState(nCells); - it('should accept integer values for conditions', async (t) => { - state = solver.createState(1); - // Test setting int values (from Python test) - state.setConditions({ temperatures: 272, pressures: 101325 }); - const conditions = state.getConditions(); - - assert.ok(isClose(conditions.temperature[0], 272, 1e-13), 'Temperature should be 272'); - assert.ok(isClose(conditions.pressure[0], 101325, 1e-13), 'Pressure should be 101325'); - }); -}); - -describe('User-defined rate parameters', () => { - let solver; - let state; + // Create test values + const temperatures = Array.from({ length: nCells }, (_, i) => 300.0 + i); + const pressures = Array.from({ length: nCells }, () => 101325.0); + const airDensities = temperatures.map((T, i) => 101325.0 / (8.31446261815324 * T)); - before(() => { - const configPath = createTestMechanism(); - solver = MICM.fromConfigPath(configPath); - }); + state.setConditions({ temperatures, pressures }); + const conditions = state.getConditions(); - it('should set and get user-defined rate parameters for single grid cell', async (t) => { - state = solver.createState(1); - const params = { 'USER.reaction 1': 1.0 }; - state.setUserDefinedRateParameters(params); - const result = state.getUserDefinedRateParameters(); + assert.strictEqual(conditions.length, nCells, `Should have ${nCells} grid cells`); - assert.ok(isClose(result['USER.reaction 1'][0], 1.0, 1e-13), 'Rate parameter should be 1.0'); + for (let i = 0; i < nCells; i++) { + assert.ok(isClose(conditions[i].temperature, temperatures[i], 1e-13), `Temperature[${i}] should match`); + assert.ok(isClose(conditions[i].pressure, pressures[i], 1e-13), `Pressure[${i}] should match`); + assert.ok(isClose(conditions[i].air_density, airDensities[i], 0.1), `Air density[${i}] should match`); + } + } }); - it('should set and get user-defined rate parameters for multiple grid cells', async (t) => { - state = solver.createState(2); - const params = { 'USER.reaction 1': [1.0, 2.0] }; - state.setUserDefinedRateParameters(params); - const result = state.getUserDefinedRateParameters(); + it('should handle integer values for conditions', async () => { + for (const nCells of gridCellsList) { + state = solver.createState(nCells); - assert.deepStrictEqual(result['USER.reaction 1'], [1.0, 2.0], 'Rate parameters should match'); - }); + const temperatures = Array.from({ length: nCells }, () => 272); + const pressures = Array.from({ length: nCells }, () => 101325); + const airDensities = temperatures.map(T => 101325 / (8.31446261815324 * T)); - it('should handle empty rate parameter update', async (t) => { - state = solver.createState(1); - const params = { 'USER.reaction 1': 1.0 }; - state.setUserDefinedRateParameters(params); + state.setConditions({ temperatures, pressures }); + const conditions = state.getConditions(); - // Set empty parameters - should not change values - state.setUserDefinedRateParameters({}); - const result = state.getUserDefinedRateParameters(); + assert.strictEqual(conditions.length, nCells, `Should have ${nCells} grid cells`); - assert.ok(isClose(result['USER.reaction 1'][0], 1.0, 1e-13), 'Rate parameter should remain 1.0'); + for (let i = 0; i < nCells; i++) { + assert.ok(isClose(conditions[i].temperature, 272, 1e-13), `Temperature[${i}] should be 272`); + assert.ok(isClose(conditions[i].pressure, 101325, 1e-13), `Pressure[${i}] should be 101325`); + assert.ok(isClose(conditions[i].air_density, airDensities[i], 1e-13), `Air density[${i}] should be calculated`); + } + } }); }); -describe('State ordering', () => { + +describe('User-defined rate parameters', () => { let solver; let state; + const gridCellsList = [1, 5, 10, 20]; before(() => { - const configPath = createTestMechanism(); + const configPath = getConfigPath(); solver = MICM.fromConfigPath(configPath); - state = solver.createState(1); }); - it('should get species ordering', async (t) => { - const ordering = state.getSpeciesOrdering(); + it('should set and get user-defined rate parameters for multiple grid cells', async () => { + for (const nCells of gridCellsList) { + state = solver.createState(nCells); - assert.ok(typeof ordering === 'object', 'Ordering should be an object'); - assert.ok('A' in ordering, 'Should have species A'); - assert.ok('B' in ordering, 'Should have species B'); - assert.ok('C' in ordering, 'Should have species C'); - assert.ok(ordering.A >= 0, 'A ordering should be non-negative'); - assert.ok(ordering.B >= 0, 'B ordering should be non-negative'); - assert.ok(ordering.C >= 0, 'C ordering should be non-negative'); - }); - - it('should get user-defined rate parameters ordering', async (t) => { - const ordering = state.getUserDefinedRateParametersOrdering(); + const params = { 'USER.reaction 1': Array.from({ length: nCells }, (_, i) => 1.0 + i) }; + state.setUserDefinedRateParameters(params); + const result = state.getUserDefinedRateParameters(); - assert.ok(typeof ordering === 'object', 'Ordering should be an object'); - assert.ok('USER.reaction 1' in ordering, 'Should have USER.reaction 1'); - assert.ok('USER.reaction 2' in ordering, 'Should have USER.reaction 2'); - assert.ok(ordering['USER.reaction 1'] >= 0, 'Ordering should be non-negative'); - assert.ok(ordering['USER.reaction 2'] >= 0, 'Ordering should be non-negative'); + assert.strictEqual(result['USER.reaction 1'].length, nCells, `Should have ${nCells} values`); + for (let i = 0; i < nCells; i++) { + assert.ok(isClose(result['USER.reaction 1'][i], 1.0 + i, 1e-13), `Rate parameter[${i}] should match`); + } + } }); -}); -describe('Grid cell operations', () => { - it('should return correct number of grid cells', async (t) => { - const configPath = createTestMechanism(); - const solver = MICM.fromConfigPath(configPath); + it('should handle empty rate parameter update', async () => { + for (const nCells of gridCellsList) { + state = solver.createState(nCells); + + const params = { 'USER.reaction 1': Array.from({ length: nCells }, () => 1.0) }; + state.setUserDefinedRateParameters(params); - const state1 = solver.createState(1); - assert.strictEqual(state1.getNumberOfGridCells(), 1, 'Should have 1 grid cell'); + // Set empty parameters - should not change values + state.setUserDefinedRateParameters({}); + const result = state.getUserDefinedRateParameters(); - const state5 = solver.createState(5); - assert.strictEqual(state5.getNumberOfGridCells(), 5, 'Should have 5 grid cells'); + for (let i = 0; i < nCells; i++) { + assert.ok(isClose(result['USER.reaction 1'][i], 1.0, 1e-13), `Rate parameter[${i}] should remain 1.0`); + } + } }); }); + diff --git a/package-lock.json b/package-lock.json index c8a12750c..1cfc50fe6 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,13 @@ { "name": "@ncar/musica", - "version": "0.14.1", + "version": "0.14.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@ncar/musica", - "version": "0.14.1", + "version": "0.14.2", + "license": "Apache-2.0", "devDependencies": { "c8": "^10.1.3" } diff --git a/package.json b/package.json index 40537831c..5b2dc7ab7 100644 --- a/package.json +++ b/package.json @@ -15,7 +15,7 @@ ], "scripts": { "build": "npm run build:wasm", - "build:wasm": "emcmake cmake -S . -B build-wasm -DCMAKE_BUILD_TYPE=Debug -DMUSICA_ENABLE_JAVASCRIPT=ON -DMUSICA_ENABLE_TUVX=OFF -DMUSICA_ENABLE_CARMA=OFF && cmake --build build-wasm", + "build:wasm": "emcmake cmake -S . -B build-wasm -DCMAKE_BUILD_TYPE=Debug -DMUSICA_ENABLE_JAVASCRIPT=ON -DMUSICA_ENABLE_TUVX=OFF -DMUSICA_ENABLE_CARMA=OFF -DMUSICA_ENABLE_TESTS=OFF && cmake --build build-wasm", "test": "node --test javascript/tests/unit/*.test.js javascript/tests/integration/*.test.js", "test:integration": "node --test javascript/tests/integration/*.test.js", "test:unit": "node --test javascript/tests/unit/**/*.test.js", diff --git a/python/bindings/micm/micm.cpp b/python/bindings/micm/micm.cpp index 7a36b43ca..d8516eed1 100644 --- a/python/bindings/micm/micm.cpp +++ b/python/bindings/micm/micm.cpp @@ -34,20 +34,7 @@ void bind_micm(py::module_& micm) .value("AcceptingUnconvergedIntegration", micm::SolverState::AcceptingUnconvergedIntegration) .export_values(); - micm.def( - "_vector_size", - [](const musica::MICMSolver solver_type) - { - switch (solver_type) - { - case musica::MICMSolver::Rosenbrock: - case musica::MICMSolver::BackwardEuler: - case musica::MICMSolver::CudaRosenbrock: return musica::MUSICA_VECTOR_SIZE; - case musica::MICMSolver::RosenbrockStandardOrder: - case musica::MICMSolver::BackwardEulerStandardOrder: return static_cast(1); - default: throw py::value_error("Invalid MICM solver type."); - } - }, + micm.def("_vector_size", &musica::GetVectorSize, "Returns the vector dimension for vector-ordered solvers, 0 otherwise."); micm.def( diff --git a/src/micm/micm.cpp b/src/micm/micm.cpp index 75388af6a..6f35a30e3 100644 --- a/src/micm/micm.cpp +++ b/src/micm/micm.cpp @@ -33,6 +33,11 @@ namespace musica } } + MICM::MICM(std::string config_path, MICMSolver solver_type) + : MICM(ReadConfiguration(config_path), solver_type) + { + } + MICM::MICM(const Chemistry& chemistry, MICMSolver solver_type) { auto configure = [&](auto builder) diff --git a/src/micm/micm_c_interface.cpp b/src/micm/micm_c_interface.cpp index e0ebb170f..7db005b5d 100644 --- a/src/micm/micm_c_interface.cpp +++ b/src/micm/micm_c_interface.cpp @@ -157,4 +157,17 @@ namespace musica return micm->GetMaximumNumberOfGridCells(); } + std::size_t GetVectorSize(musica::MICMSolver solver_type) + { + switch (solver_type) + { + case musica::MICMSolver::Rosenbrock: + case musica::MICMSolver::BackwardEuler: + case musica::MICMSolver::CudaRosenbrock: return musica::MUSICA_VECTOR_SIZE; + case musica::MICMSolver::RosenbrockStandardOrder: + case musica::MICMSolver::BackwardEulerStandardOrder: return static_cast(1); + default: throw std::runtime_error("Invalid MICM solver type."); + } + } + } // namespace musica \ No newline at end of file diff --git a/src/micm/state.cpp b/src/micm/state.cpp index 2ff9aefe2..e97550f99 100644 --- a/src/micm/state.cpp +++ b/src/micm/state.cpp @@ -5,6 +5,7 @@ // solver configurations, leveraging both vector-ordered and standard-ordered state type. // It also includes functions for creating and deleting State instances with c bindings. #include +#include #include #include @@ -52,6 +53,102 @@ namespace musica state_variant_); } + void State::SetConcentrations(const std::map>& input, musica::MICMSolver solver_type) + { + std::visit( + [&](auto& st) + { + std::size_t vector_size_ = musica::GetVectorSize(solver_type); + size_t n_species = st.variable_map_.size(); + for (const auto& [name, values] : input) + { + auto it = st.variable_map_.find(name); + if (it == st.variable_map_.end()) + continue; + size_t i_species = it->second; + for (size_t i_cell = 0; i_cell < values.size(); ++i_cell) + { + size_t group_index = i_cell / vector_size_; + size_t row_in_group = i_cell % vector_size_; + size_t idx = (group_index * n_species + i_species) * vector_size_ + row_in_group; + st.variables_.AsVector()[idx] = values[i_cell]; + } + } + }, + state_variant_); + } + + std::map> State::GetConcentrations(musica::MICMSolver solver_type) const + { + return std::visit( + [&](auto& st) + { + std::size_t vector_size_ = musica::GetVectorSize(solver_type); + std::map> output; + size_t n_species = st.variable_map_.size(); + for (const auto& [name, i_species] : st.variable_map_) + { + output[name] = std::vector(st.NumberOfGridCells()); + for (size_t i_cell = 0; i_cell < st.NumberOfGridCells(); ++i_cell) + { + size_t group_index = i_cell / vector_size_; + size_t row_in_group = i_cell % vector_size_; + size_t idx = (group_index * n_species + i_species) * vector_size_ + row_in_group; + output[name][i_cell] = st.variables_.AsVector()[idx]; + } + } + return output; + }, + state_variant_); + } + + void State::SetRateConstants(const std::map>& input, musica::MICMSolver solver_type) { + std::visit( + [&](auto& st) + { + std::size_t vector_size_ = musica::GetVectorSize(solver_type); + size_t n_params = st.custom_rate_parameter_map_.size(); + for (const auto& [name, values] : input) + { + auto it = st.custom_rate_parameter_map_.find(name); + if (it == st.custom_rate_parameter_map_.end()) + continue; + size_t i_param = it->second; + for (size_t i_cell = 0; i_cell < values.size(); ++i_cell) + { + size_t group_index = i_cell / vector_size_; + size_t row_in_group = i_cell % vector_size_; + size_t idx = (group_index * n_params + i_param) * vector_size_ + row_in_group; + st.custom_rate_parameters_.AsVector()[idx] = values[i_cell]; + } + } + }, + state_variant_); + } + + std::map> State::GetRateConstants(musica::MICMSolver solver_type) const { + return std::visit( + [&](auto& st) + { + std::size_t vector_size_ = musica::GetVectorSize(solver_type); + std::map> output; + size_t n_params = st.custom_rate_parameter_map_.size(); + for (const auto& [name, i_param] : st.custom_rate_parameter_map_) + { + output[name] = std::vector(st.NumberOfGridCells()); + for (size_t i_cell = 0; i_cell < st.NumberOfGridCells(); ++i_cell) + { + size_t group_index = i_cell / vector_size_; + size_t row_in_group = i_cell % vector_size_; + size_t idx = (group_index * n_params + i_param) * vector_size_ + row_in_group; + output[name][i_cell] = st.custom_rate_parameters_.AsVector()[idx]; + } + } + return output; + }, + state_variant_); + } + std::vector& State::GetOrderedConcentrations() { return std::visit([](auto& st) -> std::vector& { return st.variables_.AsVector(); }, state_variant_); diff --git a/src/test/unit/micm/micm_wrapper.cpp b/src/test/unit/micm/micm_wrapper.cpp index 640b88ff8..a75f05d50 100644 --- a/src/test/unit/micm/micm_wrapper.cpp +++ b/src/test/unit/micm/micm_wrapper.cpp @@ -1,4 +1,5 @@ #include +#include #include #include