Skip to content

Commit

Permalink
Refactor interpolation and yield calculation methods
Browse files Browse the repository at this point in the history
Simplified the procedure to use interpolators and fit in interpolation points. This overhaul has led to adjustments in multiple files, including the Interpolator trait and classes like NaturalCubic and CatmullRom. YieldCurve's calculation methods have also been updated to match this change. Additionally, the test files have been adjusted to follow the new way of generating interpolators with chains and fit calls.
  • Loading branch information
nakashima-hikaru committed Jun 12, 2024
1 parent b822935 commit ed8e078
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 41 deletions.
2 changes: 1 addition & 1 deletion crates/qlab-instrument/src/bond.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ impl<V: Value> Bond<V> {
///
/// # Errors
/// Error occurs if a discount factor calculation fails
pub fn discounted_value<D: DayCount, I: Interpolator<V>>(
pub fn discounted_value<D: DayCount, I: Interpolator<I, V>>(
&self,
bond_settle_date: Date,
yield_curve: &YieldCurve<D, V, I>,
Expand Down
4 changes: 2 additions & 2 deletions crates/qlab-math/src/interpolation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub(crate) struct Point<V> {
y: V,
}

pub trait Interpolator<V: Value> {
pub trait Interpolator<I, V: Value>: Default {
/// Fits the model to the given data points.
///
/// This function adjusts the parameters of the model to minimize the difference
Expand All @@ -24,7 +24,7 @@ pub trait Interpolator<V: Value> {
/// # Errors
///
/// Returns an error if the fitting process fails.
fn try_fit(&mut self, xs_and_ys: &[(V, V)]) -> Result<(), InterpolationError<V>>;
fn try_fit(self, xs_and_ys: &[(V, V)]) -> Result<I, InterpolationError<V>>;

/// Returns the value of type `V` and wraps it in a `QLabResult`.
///
Expand Down
6 changes: 3 additions & 3 deletions crates/qlab-math/src/interpolation/linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ impl<V: Real> Linear<V> {
}
}

impl<V: Value> Interpolator<V> for Linear<V> {
fn try_fit(&mut self, raw_points: &[(V, V)]) -> Result<(), InterpolationError<V>> {
impl<V: Value> Interpolator<Linear<V>, V> for Linear<V> {
fn try_fit(mut self, raw_points: &[(V, V)]) -> Result<Self, InterpolationError<V>> {
let mut points = Vec::with_capacity(raw_points.len());
for &(x, y) in raw_points {
points.push(Point { x, y });
}
self.points = points;
Ok(())
Ok(self)
}

/// Calculates the value at time `t` using linear interpolation based on a grid of points.
Expand Down
9 changes: 4 additions & 5 deletions crates/qlab-math/src/interpolation/spline/catmull_rom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub struct CatmullRom<V: Value> {
points: Vec<Point2<V>>,
}

impl<V: Value> Interpolator<V> for CatmullRom<V> {
impl<V: Value> Interpolator<CatmullRom<V>, V> for CatmullRom<V> {
/// Constructs a new `CatmullRom` from a slice of raw points.
///
/// # Arguments
Expand All @@ -31,7 +31,7 @@ impl<V: Value> Interpolator<V> for CatmullRom<V> {
/// * `InterpolationError::InsufficientPointsError(n)` - If the number of `raw_points` is less than 3, where `n` is the number of `raw_points`.
/// * `InterpolationError::PointOrderError` - If the x-coordinates of the `raw_points` are not in ascending order.
///
fn try_fit(&mut self, raw_points: &[(V, V)]) -> Result<(), InterpolationError<V>> {
fn try_fit(mut self, raw_points: &[(V, V)]) -> Result<Self, InterpolationError<V>> {
if raw_points.len() < 3 {
return Err(InterpolationError::InsufficientPointsError(
raw_points.len(),
Expand All @@ -48,7 +48,7 @@ impl<V: Value> Interpolator<V> for CatmullRom<V> {
points.push(point);
}
self.points = points;
Ok(())
Ok(self)
}
/// Tries to find the value `x` in the Hermite spline.
///
Expand Down Expand Up @@ -193,8 +193,7 @@ mod tests {
#[test]
fn test_f64() {
let points = [(0.0, 1.0), (0.5, 0.5), (1.0, 0.0)];
let mut interpolator = CatmullRom::default();
interpolator.try_fit(&points).unwrap();
let interpolator = CatmullRom::default().try_fit(&points).unwrap();
let val = interpolator.try_value(0.75).unwrap();
assert!((val - 0.270_833_333_333_333_37_f64).abs() < f64::EPSILON);
}
Expand Down
10 changes: 4 additions & 6 deletions crates/qlab-math/src/interpolation/spline/natural_cubic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub struct NaturalCubic<V: Value> {
points: Vec<Point3<V>>,
}

impl<V: Value> Interpolator<V> for NaturalCubic<V> {
impl<V: Value> Interpolator<NaturalCubic<V>, V> for NaturalCubic<V> {
/// Tries to create a new `NaturalCubic` from the given raw points.
///
/// # Arguments
Expand All @@ -36,7 +36,7 @@ impl<V: Value> Interpolator<V> for NaturalCubic<V> {
///
/// # Panics
/// Will panic if `V` fail to cast constants.
fn try_fit(&mut self, raw_points: &[(V, V)]) -> Result<(), InterpolationError<V>> {
fn try_fit(mut self, raw_points: &[(V, V)]) -> Result<Self, InterpolationError<V>> {
if raw_points.len() < 3 {
return Err(InterpolationError::InsufficientPointsError(
raw_points.len(),
Expand Down Expand Up @@ -90,8 +90,7 @@ impl<V: Value> Interpolator<V> for NaturalCubic<V> {
}

self.points = points;

Ok(())
Ok(self)
}

/// Evaluates the Hermite spline at the given value `x`.
Expand Down Expand Up @@ -150,8 +149,7 @@ mod tests {
#[test]
fn test_f64() {
let points = [(0.0, 1.0), (0.5, 0.5), (1.0, 0.0)];
let mut interpolator = NaturalCubic::default();
interpolator.try_fit(&points).unwrap();
let interpolator = NaturalCubic::default().try_fit(&points).unwrap();
let val = interpolator.try_value(0.75).unwrap();
assert!((0.25_f64 - val) / 0.25_f64 < f64::EPSILON);
}
Expand Down
1 change: 1 addition & 0 deletions crates/qlab-math/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ pub trait Value:
+ DivAssign
+ FromPrimitive
+ Real
+ Default
{
}
31 changes: 10 additions & 21 deletions crates/qlab-termstructure/src/yield_curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ use std::marker::PhantomData;
/// A trait representing a yield curve with discount factor calculations.
///
/// The trait is generic over the type of Realing point values (`V`) and the day count convention (`D`).
pub struct YieldCurve<D: DayCount, V: Value, I: Interpolator<V>> {
pub struct YieldCurve<D: DayCount, V: Value, I: Interpolator<I, V>> {
settlement_date: Date,
interpolator: I,
_phantom: PhantomData<V>,
_day_count: PhantomData<D>,
}

impl<V: Value, D: DayCount, I: Interpolator<V>> YieldCurve<D, V, I> {
impl<V: Value, D: DayCount, I: Interpolator<I, V>> YieldCurve<D, V, I> {
/// Creates a new instance of the `QLab` struct.
///
/// # Arguments
Expand All @@ -33,12 +33,7 @@ impl<V: Value, D: DayCount, I: Interpolator<V>> YieldCurve<D, V, I> {
///
/// # Errors
/// Returns an `Err` variant if the lengths of `maturities` and `spot_yields` do not match.
pub fn new(
settlement_date: Date,
maturities: &[Date],
spot_yields: &[V],
mut interpolator: I,
) -> QLabResult<Self> {
pub fn new(settlement_date: Date, maturities: &[Date], spot_yields: &[V]) -> QLabResult<Self> {
if maturities.len() != spot_yields.len() {
return Err(
InvalidInput("maturities and spot_yields are different lengths".into()).into(),
Expand All @@ -53,7 +48,7 @@ impl<V: Value, D: DayCount, I: Interpolator<V>> YieldCurve<D, V, I> {
.copied()
.zip(spot_yields.iter().copied())
.collect();
interpolator.try_fit(&val)?;
let interpolator = I::default().try_fit(&val)?;
Ok(Self {
_phantom: PhantomData,
settlement_date,
Expand Down Expand Up @@ -122,9 +117,9 @@ mod tests {
#[derive(Default)]
struct Flat(f64);

impl Interpolator<f64> for Flat {
fn try_fit(&mut self, _x_and_y: &[(f64, f64)]) -> Result<(), InterpolationError<f64>> {
Ok(())
impl Interpolator<Flat, f64> for Flat {
fn try_fit(self, _x_and_y: &[(f64, f64)]) -> Result<Self, InterpolationError<f64>> {
Ok(self)
}

fn try_value(&self, _t: f64) -> Result<f64, InterpolationError<f64>> {
Expand All @@ -137,19 +132,13 @@ mod tests {
let settlement_date = Date::from_ymd(2022, 12, 31).unwrap();
let maturities = vec![Date::from_ymd(2022, 12, 31).unwrap()];
let spot_yields = vec![0.02]; // 2% yield
let interpolator = Flat(0.0);

let yield_curve = YieldCurve::<Act365, _, _>::new(
settlement_date,
&maturities,
&spot_yields,
interpolator,
)
.unwrap();
let yield_curve =
YieldCurve::<Act365, _, Flat>::new(settlement_date, &maturities, &spot_yields).unwrap();

let d1 = Date::from_ymd(2023, 1, 1).unwrap();
let d2 = Date::from_ymd(2023, 12, 31).unwrap();
let discount_factor = yield_curve.discount_factor(d1, d2).unwrap();
assert!((discount_factor - 1.0).abs() < f64::EPSILON);
assert!((discount_factor - 1.0_f64).abs() < f64::EPSILON);
}
}
5 changes: 2 additions & 3 deletions crates/qlab/tests/calculate_bond.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ fn main() {
let spot_yields: Vec<f64> = vec![
0.02, 0.0219, 0.0237, 0.0267, 0.0312, 0.0343, 0.0378, 0.0393, 0.04, 0.0401, 0.0401, 0.04,
];
let interpolator = Linear::new();
let yield_curve: YieldCurve<Act365, _, _> =
YieldCurve::new(spot_settle_date, &maturities, &spot_yields, interpolator).unwrap();
let yield_curve: YieldCurve<Act365, _, Linear<f64>> =
YieldCurve::new(spot_settle_date, &maturities, &spot_yields).unwrap();
let val = bond_20_yr
.discounted_value(spot_settle_date, &yield_curve)
.unwrap();
Expand Down

0 comments on commit ed8e078

Please sign in to comment.