Very rough port for matrix neural net example
Will need a lot of tweaks to turn this into something less messy
Skeletonxf committed Feb 12, 2024
1 parent 9abd524 commit 82af6c8
Showing 1 changed file with 121 additions and 39 deletions.

src/neural_networks.rs (+121 -39)
@@ -25,9 +25,10 @@ compiler gets confused trying to infer the type.
```
use easy_ml::matrices::Matrix;
use easy_ml::matrices::views::{MatrixRange, MatrixView, MatrixRef, NoInteriorMutability};
use easy_ml::numeric::{Numeric, NumericRef};
use easy_ml::numeric::extra::{Real, RealRef, Exp};
use easy_ml::differentiation::{Record, WengertList};
use easy_ml::differentiation::{Record, RecordMatrix, WengertList, Index};
use rand::{Rng, SeedableRng};
use rand::distributions::Standard;
@@ -65,14 +66,30 @@ where for<'a> &'a T: NumericRef<T> + RealRef<T> {
    (((input * w1).map(sigmoid) * w2).map(sigmoid) * w3).scalar()
}
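// scalar() reads the single value out of the final 1x1 matrix product;
// model_2 below does the equivalent for records with get_as_record(0, 0).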
fn model_2<'a, I>(
    input: &RecordMatrix<'a, f32, I>,
    w1: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w2: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w3: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>
) -> Record<'a, f32>
where
    I: MatrixRef<(f32, Index)> + NoInteriorMutability,
{
    (((input * w1).map(sigmoid).unwrap() * w2).map(sigmoid).unwrap() * w3).get_as_record(0, 0)
}
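// Both models map a sigmoid activation over each layer's outputs; its
// definition sits earlier in this file, outside this diff, and is assumed
// to be the standard logistic function
//     sigmoid(x) = 1 / (1 + e^(-x))
// which squashes each value into (0, 1) before the next multiplication.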
/**
 * Computes mean squared loss of the network against all the training data.
 *
 * This is written for a generic type, so it can be used with records and also
 * with normal floats.
 */
fn mean_squared_loss<T: Numeric + Real + Copy>(
    inputs: &Vec<Matrix<T>>, w1: &Matrix<T>, w2: &Matrix<T>, w3: &Matrix<T>, labels: &Vec<T>
fn mean_squared_loss<T: Numeric + Real + Copy, I>(
    inputs: &Vec<Matrix<T>>,
    w1: &Matrix<T>,
    w2: &Matrix<T>,
    w3: &Matrix<T>,
    labels: &Vec<T>
) -> T
where for<'a> &'a T: NumericRef<T> + RealRef<T> {
    inputs.iter().enumerate().fold(T::zero(), |acc, (i, input)| {
@@ -83,18 +100,48 @@ where for<'a> &'a T: NumericRef<T> + RealRef<T> {
    }) / T::from_usize(inputs.len()).unwrap()
}
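// Concretely this computes L = (1/N) * sum_i (label_i - prediction_i)^2;
// for example, predictions [0.9, 0.2] against labels [1.0, 0.0] give
// ((0.1)^2 + (-0.2)^2) / 2 = 0.025.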
fn mean_squared_loss_2<'a>(
    inputs: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w1: &mut RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w2: &mut RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w3: &mut RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    labels: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
) -> Record<'a, f32>
{
    let rows = inputs.rows();
    let columns = inputs.columns();
    let history = w1.history();
    let mut loss = Record::constant(0.0);
    for r in 0..rows {
        let input = RecordMatrix::from_iter(
            (1, columns),
            inputs.iter_row_major_as_records().skip(columns * r).take(columns)
        ).expect("Splitting inputs into RecordMatrix for each row should match expected size");
        let correct = labels.get_as_record(0, r);
        let output = model_2(&input, w1, w2, w3);
        // sum up the squared loss
        loss = loss + ((correct - output) * (correct - output));
    }
    loss / (rows as f32)
}
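// The skip/take above slices row r out of the row-major record iterator:
// for a 4x3 inputs matrix, row 2 occupies flattened positions 6, 7 and 8,
// which is exactly skip(3 * 2).take(3).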
/**
 * Updates the weight matrices by one step of gradient descent.
 *
 * Note that here we are no longer generic over the type; we need the methods
 * defined on Record to do backprop.
 */
fn step_gradient(
    inputs: &Vec<Matrix<Record<f32>>>,
    w1: &mut Matrix<Record<f32>>, w2: &mut Matrix<Record<f32>>, w3: &mut Matrix<Record<f32>>,
    labels: &Vec<Record<f32>>, learning_rate: f32, list: &WengertList<f32>
) -> f32 {
    let loss = mean_squared_loss::<Record<f32>>(inputs, w1, w2, w3, labels);
fn step_gradient<'a>(
    inputs: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w1: &mut RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w2: &mut RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    w3: &mut RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    labels: &RecordMatrix<'a, f32, Matrix<(f32, Index)>>,
    learning_rate: f32,
    list: &'a WengertList<f32>
) -> f32
{
    let loss = mean_squared_loss_2(inputs, w1, w2, w3, labels);
    let derivatives = loss.derivatives();
    // update each element in the weight matrices by the derivatives
    w1.map_mut(|x| x - (derivatives[&x] * learning_rate));
@@ -115,33 +162,51 @@ let mut random_generator = rand_chacha::ChaCha8Rng::seed_from_u64(25);
// randomly initialise the weights using the fixed seed generator for reproducibility
let list = WengertList::new();
// w1 will be a 3x3 matrix
let mut w1 = Matrix::from(vec![
    n_random_numbers(&mut random_generator, 3),
    n_random_numbers(&mut random_generator, 3),
    n_random_numbers(&mut random_generator, 3)
]).map(|x| Record::variable(x, &list));
let mut w1 = RecordMatrix::variables(
    &list,
    Matrix::from(
        vec![
            n_random_numbers(&mut random_generator, 3),
            n_random_numbers(&mut random_generator, 3),
            n_random_numbers(&mut random_generator, 3)
        ]
    )
);
// w2 will be a 3x3 matrix
let mut w2 = Matrix::from(vec![
    n_random_numbers(&mut random_generator, 3),
    n_random_numbers(&mut random_generator, 3),
    n_random_numbers(&mut random_generator, 3)
]).map(|x| Record::variable(x, &list));
let mut w2 = RecordMatrix::variables(
    &list,
    Matrix::from(
        vec![
            n_random_numbers(&mut random_generator, 3),
            n_random_numbers(&mut random_generator, 3),
            n_random_numbers(&mut random_generator, 3)
        ]
    )
);
// w3 will be a 3x1 column matrix
let mut w3 = Matrix::column(n_random_numbers(&mut random_generator, 3))
    .map(|x| Record::variable(x, &list));
let mut w3 = RecordMatrix::variables(
    &list,
    Matrix::column(n_random_numbers(&mut random_generator, 3))
);
println!("w1 {}", w1);
println!("w2 {}", w2);
println!("w3 {}", w3);
// define XOR inputs, with a bias input of 1.0 appended to each row
let inputs = vec![
    Matrix::row(vec![ 0.0, 0.0, 1.0 ]).map(|x| Record::constant(x)),
    Matrix::row(vec![ 0.0, 1.0, 1.0 ]).map(|x| Record::constant(x)),
    Matrix::row(vec![ 1.0, 0.0, 1.0 ]).map(|x| Record::constant(x)),
    Matrix::row(vec![ 1.0, 1.0, 1.0 ]).map(|x| Record::constant(x))
];
let inputs = RecordMatrix::constants(
    Matrix::from(
        vec![
            vec![ 0.0, 0.0, 1.0 ],
            vec![ 0.0, 1.0, 1.0 ],
            vec![ 1.0, 0.0, 1.0 ],
            vec![ 1.0, 1.0, 1.0 ],
        ]
    )
);
// define XOR outputs which will be used as labels
let labels = vec![ 0.0, 1.0, 1.0, 0.0 ].into_iter().map(|x| Record::constant(x)).collect();
let labels = RecordMatrix::constants(
    Matrix::row(vec![ 0.0, 1.0, 1.0, 0.0 ])
);
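// labels is a 1x4 row matrix, so mean_squared_loss_2 pairs input row r with
// its label via labels.get_as_record(0, r).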
let learning_rate = 0.2;
let epochs = 4000;
@@ -173,21 +238,38 @@ println!("w1 {}", w1);
println!("w2 {}", w2);
println!("w3 {}", w3);
// check that the network has learned XOR properly
println!("0 0: {:?}", model::<Record<f32>>(&inputs[0], &w1, &w2, &w3).number);
println!("0 1: {:?}", model::<Record<f32>>(&inputs[1], &w1, &w2, &w3).number);
println!("1 0: {:?}", model::<Record<f32>>(&inputs[2], &w1, &w2, &w3).number);
println!("1 1: {:?}", model::<Record<f32>>(&inputs[3], &w1, &w2, &w3).number);
let row_1 = RecordMatrix::from_existing(
    Some(&list),
    MatrixView::from(MatrixRange::from(&inputs, 0..1, 0..3))
);
let row_2 = RecordMatrix::from_existing(
    Some(&list),
    MatrixView::from(MatrixRange::from(&inputs, 1..2, 0..3))
);
let row_3 = RecordMatrix::from_existing(
    Some(&list),
    MatrixView::from(MatrixRange::from(&inputs, 2..3, 0..3))
);
let row_4 = RecordMatrix::from_existing(
    Some(&list),
    MatrixView::from(MatrixRange::from(&inputs, 3..4, 0..3))
);
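// Each row_N is a 1x3 MatrixRange view of one row of inputs, wrapped back up
// as a RecordMatrix by from_existing; passing Some(&list) presumably keeps
// the view tied to the same WengertList as the rest of the example.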
println!("0 0: {:?}", model_2(&row_1, &w1, &w2, &w3).number);
println!("0 1: {:?}", model_2(&row_2, &w1, &w2, &w3).number);
println!("1 0: {:?}", model_2(&row_3, &w1, &w2, &w3).number);
println!("1 1: {:?}", model_2(&row_4, &w1, &w2, &w3).number);
assert!(losses[epochs - 1] < 0.02);
// we can also extract the learned weights once done with training and avoid the memory
// overhead of Record
let w1_final = w1.map(|x| x.number);
let w2_final = w2.map(|x| x.number);
let w3_final = w3.map(|x| x.number);
println!("0 0: {:?}", model::<f32>(&inputs[0].map(|x| x.number), &w1_final, &w2_final, &w3_final));
println!("0 1: {:?}", model::<f32>(&inputs[1].map(|x| x.number), &w1_final, &w2_final, &w3_final));
println!("1 0: {:?}", model::<f32>(&inputs[2].map(|x| x.number), &w1_final, &w2_final, &w3_final));
println!("1 1: {:?}", model::<f32>(&inputs[3].map(|x| x.number), &w1_final, &w2_final, &w3_final));
let w1_final = w1.view().map(|(x, _)| x);
let w2_final = w2.view().map(|(x, _)| x);
let w3_final = w3.view().map(|(x, _)| x);
println!("0 0: {:?}", model::<f32>(&row_1.view().map(|(x, _)| x), &w1_final, &w2_final, &w3_final));
println!("0 1: {:?}", model::<f32>(&row_2.view().map(|(x, _)| x), &w1_final, &w2_final, &w3_final));
println!("1 0: {:?}", model::<f32>(&row_3.view().map(|(x, _)| x), &w1_final, &w2_final, &w3_final));
println!("1 1: {:?}", model::<f32>(&row_4.view().map(|(x, _)| x), &w1_final, &w2_final, &w3_final));
```
## Tensor APIs