Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Batch Normalization Layer modules #157

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

Spnetic-5
Copy link
Collaborator

Addresses #155

@milancurcic , I've included the structure of the batch normalization layer. Could you please review it and confirm whether I'm in the correct direction?

@milancurcic
Copy link
Member

Thank you @Spnetic-5, great. Yes, this is certainly in the correct direction. I haven't yet reviewed the forward and backward algorithms, let me know when it's ready for review.

Also, despite #156, we should be able to test this layer implementation on its own, without integrating it with the network. In other words, while #156 will be necessary for full integration, we can work on this implementation as a standalone layer before #156.

@Spnetic-5
Copy link
Collaborator Author

@milancurcic Please review the forward and backward pass implementations I've added based on my interpretation of the paper, also could you guide me on how we can test this layer?

@milancurcic
Copy link
Member

Thanks! Let's review it briefly on the call today.

We can test this layer independently by passing some small, known input, and comparing the result with the corresponding known output. This should be straightforward since the batchnorm operation is relatively simple (just normalization of data). The backward is the inverse operation, so as I understand it, we can pass the same expected output to recover the same expected input.

@milancurcic
Copy link
Member

@Spnetic-5 I just saw your message on Discourse, no problem; we'll proceed work on this PR as usual.

@milancurcic
Copy link
Member

See for example a program that tests forward and backward passes of the maxpool2d layer using known inputs and expected outputs:

https://github.com/modern-fortran/neural-fortran/blob/main/test/test_maxpool2d_layer.f90

We'd use the same approach to test a batchnorm layer.

@Spnetic-5
Copy link
Collaborator Author

Hello, @milancurcic. Sorry for the lack of activity over the past few days; this was my final week of internship in Canada, and I'll be returning to India on Monday.

I added a test module for the batch norm layer, however it has some error; I believe I will need your assistance on this.

@milancurcic
Copy link
Member

No worries at all, thanks for all the work. I'll review it tomorrow.

Comment on lines +17 to +18
allocate(res % input(num_features, num_features))
allocate(res % output(num_features, num_features))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the shape for these should be (num_features, num_features), but rather (batch_size, num_features). The batch_size also won't be known until the first forward pass, so we should defer the allocation until then. In the forward pass, we could have a simple allocated check to see if they have not been allocated then, and allocate them to the shape of the input.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I meant (num_features, batch_size).

Comment on lines +51 to +52
normalized_input => (input - reshape(self % running_mean, shape(input, 1))) &
/ sqrt(reshape(self % running_var, shape(input, 1)) + self % epsilon) &
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

running_mean and running_var are not yet updated anywhere, only initialized.

sample_input = 1.0
gradient = 2.0

!TODO run forward and backward passes directly on the batchnorm_layer instance
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add a simple test directly on the batchnorm_layer instance rather than the high-level layer_type instance so you get the idea how it will work.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants