forked from maxjiang93/space_time_pde
-
Notifications
You must be signed in to change notification settings - Fork 0
/
implicit_net_test.py
29 lines (22 loc) · 912 Bytes
/
implicit_net_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""Unit Test for implicit_net."""
# pylint: disable=import-error, no-member, too-many-arguments, no-self-use
import unittest
import numpy as np
import torch
from parameterized import parameterized
import implicit_net
class ImplicitNetTest(unittest.TestCase):
    """Unit tests for ``implicit_net.ImNet``."""

    @parameterized.expand((
        [32, 2048, 4, 3, 32, 16],
    ))
    def test_imnet(self, batch_size, npts, n_in, n_out, n_chan, n_filter):
        """Check that ImNet produces an output of shape (batch_size, npts, n_out).

        Random coordinates and random per-point feature channels are
        concatenated along the last axis and fed through a freshly
        constructed ``ImNet``.

        Args:
            batch_size: number of samples in the random input batch.
            npts: number of points per sample.
            n_in: width of the coordinate part of each point (passed as ``dim``).
            n_out: expected size of the output's last axis (``out_features``).
            n_chan: width of the feature part of each point (``in_features``).
            n_filter: value passed as ImNet's ``nf`` argument
                (presumably a hidden-layer width -- confirm in implicit_net).
        """
        # Seed the RNG so a failing run can be reproduced exactly.
        torch.manual_seed(0)
        input_coords = torch.rand(batch_size, npts, n_in)
        input_chan = torch.rand(batch_size, npts, n_chan)
        # ImNet is fed coordinates and features concatenated on the last axis.
        inputs = torch.cat([input_coords, input_chan], dim=-1)
        model = implicit_net.ImNet(
            dim=n_in, in_features=n_chan, out_features=n_out, nf=n_filter)
        # Only the output shape is checked, so skip building an autograd graph.
        with torch.no_grad():
            out = model(inputs)
        # Shapes are integers: compare exactly instead of with a float tolerance.
        np.testing.assert_array_equal(
            tuple(out.shape), (batch_size, npts, n_out))
if __name__ == "__main__":
    # Run the test cases defined in this module when it is executed as a script.
    unittest.main()