Hi, I'd like to call the `_local_pattern` function in the code. What value should the `center_nodes` parameter take?

```python
import numpy as np
import torch

def _local_pattern(self, center_nodes, r=0.1, r_resolution=100, phi_resolution=360):
    # Kernel visualization is only supported by the CLCRN / CLCSTN models
    assert self._model_name in ['CLCRN', 'CLCSTN'], 'the model does not provide the kernel visualization'
    with torch.no_grad():
        # Convert the requested centers to a float tensor on the model's device
        center_nodes = torch.from_numpy(np.array(center_nodes)).float().to(self._device)
        N = center_nodes.shape[0]
        angle_ratio = 1 / phi_resolution
        # Polar sampling grid: r_resolution radii in [0, r], phi_resolution angles in [-pi, pi]
        rs = np.linspace(0, r, r_resolution)
        phis = np.linspace(-np.pi, np.pi, phi_resolution)
        xs = torch.from_numpy(rs[:, None] * np.cos(phis)[None, :]).float().to(self._device).flatten()  # r_res * phi_res
        ys = torch.from_numpy(rs[:, None] * np.sin(phis)[None, :]).float().to(self._device).flatten()  # r_res * phi_res
        # Offsets of shape (N, r_res * phi_res, 2), the same grid tiled for every center
        vs = torch.stack([xs, ys], dim=-1)[None, :, :].repeat(N, 1, 1)
        kernel = self.model.get_kernel()
        local_pattern = kernel.kernel_prattern(center_nodes, vs, angle_ratio)
    return local_pattern, center_nodes, rs, phis
```
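The issue does not include an answer from the maintainers. The function body shows that `center_nodes` is cast to a float tensor and that only its first dimension `N` is used, to tile the polar offsets `vs` once per center. This suggests an array of `N` center coordinates, for example shape `(N, 2)`, rather than integer node indices. That reading is an assumption, not something the authors confirmed. The NumPy-only sketch below reproduces the grid construction so the shapes can be checked without a trained model. The helper name `polar_offsets` and the example coordinates are hypothetical.

```python
import numpy as np

def polar_offsets(n_centers, r=0.1, r_resolution=100, phi_resolution=360):
    """Hypothetical helper: rebuild the polar offset grid used in _local_pattern with NumPy."""
    rs = np.linspace(0, r, r_resolution)                  # radii from 0 to r
    phis = np.linspace(-np.pi, np.pi, phi_resolution)     # angles from -pi to pi
    xs = (rs[:, None] * np.cos(phis)[None, :]).flatten()  # r_res * phi_res x-offsets
    ys = (rs[:, None] * np.sin(phis)[None, :]).flatten()  # r_res * phi_res y-offsets
    # Same grid for every center: shape (n_centers, r_res * phi_res, 2)
    vs = np.stack([xs, ys], axis=-1)[None, :, :].repeat(n_centers, axis=0)
    return vs, rs, phis

# Hypothetical input (assumption): two centers given as 2-D coordinates, shape (N, 2)
center_nodes = np.array([[0.0, 0.0], [0.5, -0.3]], dtype=np.float32)
vs, rs, phis = polar_offsets(center_nodes.shape[0])
print(center_nodes.shape)  # (2, 2)
print(vs.shape)            # (2, 36000, 2)
```

The offsets are the same for every center. Any per-center variation in the returned pattern therefore has to come from how `kernel_prattern` combines `center_nodes` with `vs`, which is defined in the repository and not shown here.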