BCHW format #7
You can use transpose:
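The inline snippet was lost in this copy; as a minimal sketch, assuming an input x in BCHW layout, the suggested transpose chain would look like this (the exact calls are taken from the reply below):

```python
import torch

x = torch.randn(2, 3, 4, 5)            # BCHW: batch, channels, height, width
y = x.transpose(1, 2).transpose(2, 3)  # shape becomes (2, 4, 5, 3), i.e. BHWC
```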
But transpose(1,2).transpose(2,3) does not seem to rearrange the internal array. So if I run the code, at lines 44 and 45 in my_lib.c,
xf is not valid, because grids_strideWidth is still 1. I guess it needs to be changed accordingly,
although I have not tested it.
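A small demonstration of the stride issue (a sketch with arbitrary shapes, not the repo's actual tensors): transpose only rewrites the stride metadata, so C code that reads the raw strides still sees the old BCHW layout.

```python
import torch

g = torch.randn(2, 3, 4, 5).transpose(1, 2).transpose(2, 3)  # BHWC view of BCHW data
print(g.stride())         # (60, 5, 1, 20): width stride is still 1
print(g.is_contiguous())  # False: the underlying memory was never rearranged

# A truly contiguous BHWC tensor of the same shape has different strides:
print(torch.randn(2, 4, 5, 3).stride())  # (60, 15, 3, 1): width stride is 3
```

This is why grids_strideWidth comes out as 1 instead of the channel count when the C code reads the transposed view directly.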
On a separate note, I guess BCHW should be the standard because it follows the PyTorch conv layer convention. I will probably have a version for that later. Let me know what you think.
Oh sorry, I misunderstood; you are talking about permutation for the grid rather than the image. Hmm, I always use the grid generator to produce the grid in BHWC format directly, so I never ran into the problem you mentioned.
Thank you. Meanwhile I'll use permute (or transpose) followed by contiguous(). It seems to work properly so far :)
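For reference, a minimal sketch of that workaround (the BCHW-to-BHWC permutation indices are inferred from the layouts discussed here):

```python
import torch

x = torch.randn(2, 3, 4, 5)             # BCHW
y = x.permute(0, 2, 3, 1).contiguous()  # BHWC, memory actually rearranged
print(y.stride())                       # (60, 15, 3, 1): proper contiguous strides
print(y.is_contiguous())                # True
```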
Thanks Fei for the nice work. Do you have any update on BCHW support?
@junyanz Hi Junyan, thank you for your interest; it is likely to be added after the NIPS deadline. We do find that the majority of users need BCHW instead of BHWC, so we will prioritize it :D
@fxia22 Thanks for your prompt response. Good luck with your NIPS submission. |
BCHW support added. Example can be found in
To use it, set the `layout` argument to 'BCHW' when initializing the STN.
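As a sketch of the new option (the import path is an assumption; only the `layout` argument and its 'BCHW' value come from this thread):

```python
from modules.stn import STN  # import path is an assumption

# Accepts [Batch x Channel x Height x Width] tensors directly,
# instead of the default BHWC layout.
stn = STN(layout='BCHW')
```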
Thanks a lot!
@fxia22 to go from
@edgarriba Thanks for your suggestion. As discussed above, the problem of permute is that it does not rearrange the underlying memory.
Ah, right. Permute just recomputes strides.
Excellent work!
I would like to use this in the middle of my PyTorch network, so my tensors are in [Batch x Channel x Height x Width] format. I tried to use permute to change their dimension order, but it was not successful.
For example, when a = torch.randn(2, 3, 4, 5), a.stride() is (60, 20, 5, 1), but
if I do b = a.permute(3, 0, 1, 2), b.stride() is (1, 60, 20, 5), while torch.randn(5, 2, 3, 4).stride() is (24, 12, 4, 1).
Is there an easy and efficient way to do this, or do I need to change the .c and .cu files in src?
I guess a.permute(3, 0, 1, 2).contiguous() might be a solution, but I'm not sure it is safe for Variable (autograd).
Thank you.
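For what it's worth, a sketch suggesting the contiguous() route is autograd-safe (the Variable usage matches PyTorch of this era; shapes follow the example above): both permute and contiguous() are recorded by autograd, so gradients flow back to the original tensor.

```python
import torch
from torch.autograd import Variable

a = Variable(torch.randn(2, 3, 4, 5), requires_grad=True)
b = a.permute(3, 0, 1, 2).contiguous()  # shape (5, 2, 3, 4), stride (24, 12, 4, 1)
b.sum().backward()                      # backprop passes through both ops
print(a.grad.size())                    # torch.Size([2, 3, 4, 5])
```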