Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions include/musica/micm/micm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

#include <musica/micm/chemistry.hpp>
#include <musica/micm/parse.hpp>
#include <musica/micm/state.hpp>

#include <micm/CPU.hpp>
#ifdef MUSICA_ENABLE_CUDA
Expand Down Expand Up @@ -45,7 +44,7 @@ namespace
class MusicaErrorCategory : public std::error_category
{
public:
const char *name() const noexcept override
const char* name() const noexcept override
{
return MUSICA_ERROR_CATEGORY;
}
Expand Down Expand Up @@ -74,6 +73,7 @@ inline std::error_code make_error_code(MusicaErrCode e)

namespace musica
{
class State; // forward declaration to break circular include
/// @brief Types of MICM solver
enum MICMSolver
{
Expand Down Expand Up @@ -106,7 +106,8 @@ namespace musica
public:
SolverVariant solver_variant_;

MICM(const Chemistry &chemistry, MICMSolver solver_type);
MICM(const Chemistry& chemistry, MICMSolver solver_type);
MICM(std::string config_path, MICMSolver solver_type);
MICM() = default;
~MICM()
{
Expand All @@ -116,27 +117,27 @@ namespace musica
// cuda must clean all of its runtime resources
// Otherwise, we risk the CudaRosenbrock destructor running after
// the cuda runtime has closed
std::visit([](auto &solver) { solver.reset(); }, solver_variant_);
std::visit([](auto& solver) { solver.reset(); }, solver_variant_);
micm::cuda::CudaStreamSingleton::GetInstance().CleanUp();
#endif
}

/// @brief Solve the system
/// @param state Pointer to state object
/// @param time_step Time [s] to advance the state by
micm::SolverResult Solve(musica::State *state, double time_step);
micm::SolverResult Solve(musica::State* state, double time_step);

/// @brief Get a property for a chemical species
/// @param species_name Name of the species
/// @param property_name Name of the property
/// @return Value of the property
template<class T>
T GetSpeciesProperty(const std::string &species_name, const std::string &property_name)
T GetSpeciesProperty(const std::string& species_name, const std::string& property_name)
{
micm::System system = std::visit([](auto &solver) -> micm::System { return solver->GetSystem(); }, solver_variant_);
for (const auto &phase_species : system.gas_phase_.phase_species_)
micm::System system = std::visit([](auto& solver) -> micm::System { return solver->GetSystem(); }, solver_variant_);
for (const auto& phase_species : system.gas_phase_.phase_species_)
{
const auto &species = phase_species.species_;
const auto& species = phase_species.species_;
if (species.name_ == species_name)
{
return species.GetProperty<T>(property_name);
Expand All @@ -149,7 +150,7 @@ namespace musica
/// @return Maximum number of grid cells
std::size_t GetMaximumNumberOfGridCells()
{
return std::visit([](auto &solver) { return solver->MaximumNumberOfGridCells(); }, solver_variant_);
return std::visit([](auto& solver) { return solver->MaximumNumberOfGridCells(); }, solver_variant_);
}
};

Expand Down
4 changes: 4 additions & 0 deletions include/musica/micm/micm_c_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ namespace musica
size_t GetMaximumNumberOfGridCells(MICM *micm);

bool _IsCudaAvailable(Error *error);

/// @brief Get the MUSICA vector size
/// @return The MUSICA vector size
std::size_t GetVectorSize(musica::MICMSolver);
#ifdef __cplusplus
}
#endif
Expand Down
18 changes: 18 additions & 0 deletions include/musica/micm/state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,24 @@ namespace musica
/// @param concentrations Vector of concentrations
void SetOrderedConcentrations(const std::vector<double>& concentrations);

/// @brief Set the concentrations from a map of species name to concentration vectors
/// @param input a mapping of species name to concentrations per grid cell
/// @param solver_type The solver type to use for ordering
void SetConcentrations(const std::map<std::string, std::vector<double>>& input, musica::MICMSolver solver_type);

/// @brief Get the concentrations as a map of species name to concentration vectors
/// @return Map of species name to concentration vectors
std::map<std::string, std::vector<double>> GetConcentrations(musica::MICMSolver solver_type) const;

/// @brief Set the rate constants from a map of species name to rate constant vectors
/// @param input a mapping of species name to rate constants per grid cell
/// @param solver_type The solver type to use for ordering
void SetRateConstants(const std::map<std::string, std::vector<double>>& input, musica::MICMSolver solver_type);

/// @brief Get the rate constants as a map of species name to rate constant vectors
/// @return Map of species name to rate constant vectors
std::map<std::string, std::vector<double>> GetRateConstants(musica::MICMSolver solver_type) const;

/// @brief Get the vector of rate constants
/// @return Vector of doubles
std::vector<double>& GetOrderedRateParameters();
Expand Down
8 changes: 2 additions & 6 deletions javascript/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
project(musica-wasm)

# Set C++ standard (required for MUSICA headers)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Building WASM module
add_executable(musica-wasm
src/musica_wasm.cpp
src/micm/state_wrapper.cpp
src/micm/micm_wrapper.cpp
)

target_compile_features(musica-wasm PUBLIC cxx_std_20)

# Link against MUSICA
target_link_libraries(musica-wasm musica::musica)

Expand Down
2 changes: 1 addition & 1 deletion javascript/micm/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
export { MICM } from './micm.js';
export { State } from './state.js';
export { Conditions } from './conditions.js';
export { SolverType } from './solver.js';
export { SolverType, toWasmSolverType } from './solver.js';
export { SolverState, SolverStats, SolverResult } from './solver_result.js';
export { GAS_CONSTANT, AVOGADRO, BOLTZMANN } from './utils.js';
11 changes: 6 additions & 5 deletions javascript/micm/micm.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { State } from './state.js';
import { SolverType } from './solver.js';
import { SolverType, toWasmSolverType } from './solver.js';
import { SolverStats, SolverResult } from './solver_result.js';
import { getBackend } from '../backend.js';

Expand All @@ -18,6 +18,7 @@ export class MICM {

try {
const backend = getBackend();
const wasmSolver = toWasmSolverType(solverType);
// In Node.js with NODEFS mounted, configuration files are exposed under /host.
// In browser environments, this prefix is invalid, so only add it when running under Node.
let resolvedConfigPath = configPath;
Expand All @@ -28,7 +29,7 @@ export class MICM {
if (isNodeEnv && !configPath.startsWith('/host/')) {
resolvedConfigPath = `/host/${configPath}`;
}
const nativeMICM = backend.MICM.fromConfigPath(resolvedConfigPath, solverType);
const nativeMICM = backend.MICM.fromConfigPath(resolvedConfigPath, wasmSolver);
return new MICM(nativeMICM, solverType);
} catch (error) {
throw new Error(`Failed to create MICM solver from config path: ${error.message}`);
Expand All @@ -49,10 +50,11 @@ export class MICM {

try {
const backend = getBackend();
const wasmSolver = toWasmSolverType(solverType);
const mechanismJSON = mechanism.getJSON();
const jsonString = JSON.stringify(mechanismJSON);

const nativeMICM = backend.MICM.fromConfigString(jsonString, solverType);
const nativeMICM = backend.MICM.fromConfigString(jsonString, wasmSolver);
return new MICM(nativeMICM, solverType);
} catch (error) {
throw new Error(`Failed to create MICM solver from mechanism: ${error.message}`);
Expand All @@ -76,8 +78,7 @@ export class MICM {
if (numberOfGridCells <= 0) {
throw new RangeError('number_of_grid_cells must be greater than 0');
}
const nativeState = this._nativeMICM.createState(numberOfGridCells);
return new State(nativeState);
return new State(this._nativeMICM, numberOfGridCells, this._solverType);
}

solve(state, timeStep) {
Expand Down
47 changes: 46 additions & 1 deletion javascript/micm/solver.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,51 @@
import { getBackend } from '../backend.js';

/**
* Enum for solver types
* @readonly
* @enum {number}
*/
export const SolverType = {
rosenbrock: 1, // Vector-ordered Rosenbrock solver
rosenbrock_standard_order: 2, // Standard-ordered Rosenbrock solver
backward_euler: 3, // Vector-ordered BackwardEuler solver
backward_euler_standard_order: 4, // Standard-ordered BackwardEuler solver
};
};

/**
* Converts a solver type to the WASM SolverType enum.
* Accepts either a number (1-5) or a WASM enum value.
* Throws if input is invalid.
*
* @param {number|any} solverType - Numeric solver type or WASM enum
* @returns {any} WASM SolverType enum value
*/
export function toWasmSolverType(solverType) {
if (solverType === undefined || solverType === null) {
throw new TypeError('solverType is required');
}

const backend = getBackend();
const WASMEnum = backend.SolverType;

if (typeof solverType === 'number') {
switch (solverType) {
case 1: return WASMEnum.Rosenbrock;
case 2: return WASMEnum.RosenbrockStandardOrder;
case 3: return WASMEnum.BackwardEuler;
case 4: return WASMEnum.BackwardEulerStandardOrder;
case 5: return WASMEnum.CudaRosenbrock;
default:
throw new RangeError(`Invalid numeric solverType: ${solverType}`);
}
}

// If it’s already one of the WASM enum values, accept it
for (const key of Object.keys(WASMEnum)) {
if (solverType === WASMEnum[key]) {
return solverType;
}
}

throw new TypeError('solverType must be a valid WASM MICMSolver enum or number');
}
100 changes: 53 additions & 47 deletions javascript/micm/state.js
Original file line number Diff line number Diff line change
@@ -1,82 +1,88 @@
import { isScalarNumber } from './utils.js';
import { getBackend } from '../backend.js';
import { toWasmSolverType } from './solver.js';
import { GAS_CONSTANT } from './utils.js';

export class State {
constructor(nativeState) {
this._nativeState = nativeState;
constructor(nativeMICM, numberOfGridCells, solverType) {
if (!nativeMICM) {
throw new TypeError('nativeMICM is required');
}
if (numberOfGridCells < 1) {
throw new RangeError('number_of_grid_cells must be greater than 0');
}

const backend = getBackend();
this._nativeState = backend.create_state(nativeMICM, numberOfGridCells);

this._numberOfGridCells = numberOfGridCells;
this._solverType = toWasmSolverType(solverType);
}

setConcentrations(concentrations) {
// Convert to format expected by WASM
const formatted = {};
for (const [name, value] of Object.entries(concentrations)) {
formatted[name] = isScalarNumber(value) ? [value] : value;
}
this._nativeState.setConcentrations(formatted);
this._nativeState.set_concentrations(formatted, this._solverType);
}

getConcentrations() {
return this._nativeState.getConcentrations();
return this._nativeState.get_concentrations(this._solverType);
}

setUserDefinedRateParameters(params) {
const formatted = {};
for (const [name, value] of Object.entries(params)) {
formatted[name] = isScalarNumber(value) ? [value] : value;
}
this._nativeState.setUserDefinedRateParameters(formatted);
this._nativeState.set_user_defined_constants(formatted, this._solverType);
}

getUserDefinedRateParameters() {
return this._nativeState.getUserDefinedRateParameters();
return this._nativeState.get_user_defined_constants(this._solverType);
}

setConditions({
temperatures = null,
pressures = null,
air_densities = null,
} = {}) {
const cond = {};
setConditions({ temperatures = null, pressures = null, airDensities = null } = {}) {
const backend = getBackend();
const vec = new backend.VectorConditions();

if (temperatures !== null) {
cond.temperatures = isScalarNumber(temperatures)
? [temperatures]
: temperatures;
}
if (pressures !== null) {
cond.pressures = isScalarNumber(pressures)
? [pressures]
: pressures;
}
if (air_densities !== null) {
cond.air_densities = isScalarNumber(air_densities)
? [air_densities]
: air_densities;
}
const expand = (param) => {
if (param === null || param === undefined) {
return Array(this._numberOfGridCells).fill(null);
} else if (!Array.isArray(param)) {
if (this._numberOfGridCells > 1) {
throw new Error("Scalar input requires a single grid cell");
}
return [param];
} else if (param.length !== this._numberOfGridCells) {
throw new Error(`Array input must have length ${this._numberOfGridCells}`);
}
return param;
};

this._nativeState.setConditions(cond);
}
const temps = expand(temperatures);
const pres = expand(pressures);
const dens = expand(airDensities);

getConditions() {
return this._nativeState.getConditions();
}

getSpeciesOrdering() {
return this._nativeState.getSpeciesOrdering();
}
for (let i = 0; i < this._numberOfGridCells; i++) {
const T = temps[i] !== null ? temps[i] : NaN;
const P = pres[i] !== null ? pres[i] : NaN;
let rho = dens[i] !== null ? dens[i] : (!Number.isNaN(T) && !Number.isNaN(P) ? P / (GAS_CONSTANT * T) : NaN);


getUserDefinedRateParametersOrdering() {
return this._nativeState.getUserDefinedRateParametersOrdering();
}
const cond = new backend.Condition(T, P, rho);
vec.push_back(cond);
}

getNumberOfGridCells() {
return this._nativeState.getNumberOfGridCells();
this._nativeState.set_conditions(vec);
}

concentrationStrides() {
return this._nativeState.concentrationStrides();
getConditions() {
return this._nativeState.get_conditions();
}

userDefinedRateParameterStrides() {
return this._nativeState.userDefinedRateParameterStrides();
getNumberOfGridCells() {
return this._numberOfGridCells;
}
}
}
Loading
Loading