
qgpmztmf/Stationary_Wavelet_Transform_PyTorch


Stationary Wavelet Transform PyTorch

This code computes the 2D stationary discrete wavelet transform (SWT) and its inverse in PyTorch, and gradients can be passed through both. It builds on https://github.com/fbcotter/pytorch_wavelets and is intended as a supplement to that project.
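The sketch below shows what one level of a 2D stationary (undecimated) wavelet transform and its inverse look like when written directly with `torch.nn.functional.conv2d`. It is not this repository's implementation. It makes three assumptions: it uses the Haar wavelet (the same filters as `'db1'`), it uses circular (periodic) padding instead of `mode='symmetric'`, and the helper names `haar_swt2_level` and `haar_iswt2_level` are hypothetical. Unlike the decimated DWT, it applies no downsampling, so each of the four subbands keeps the input's height and width. Every step is an ordinary differentiable tensor operation, so autograd propagates gradients through it.

```python
import torch
import torch.nn.functional as F

# Haar ('db1') analysis filters: low-pass h and high-pass g.
_h = torch.tensor([1.0, 1.0]) / 2 ** 0.5
_g = torch.tensor([1.0, -1.0]) / 2 ** 0.5
# Separable 2D filters as outer products, shape (4, 1, 2, 2):
# (h, h), (h, g), (g, h), (g, g) along (height, width).
HAAR_2D = torch.stack([torch.outer(a, b) for a in (_h, _g) for b in (_h, _g)]).unsqueeze(1)


def haar_swt2_level(x: torch.Tensor) -> torch.Tensor:
    """One undecimated analysis step (hypothetical helper): (N, 1, H, W) -> (N, 4, H, W).

    There is no stride-2 downsampling; circular padding by one pixel on the
    bottom/right keeps every subband the same size as x.
    """
    padded = F.pad(x, (0, 1, 0, 1), mode="circular")
    return F.conv2d(padded, HAAR_2D.to(x))


def haar_iswt2_level(y: torch.Tensor) -> torch.Tensor:
    """Inverse step (hypothetical helper): adjoint of the analysis step, divided by 4."""
    # The adjoint of a circular correlation is a correlation with the flipped
    # kernel, padded on the opposite (top/left) side; conv2d also sums the 4 subbands.
    adjoint = HAAR_2D.flip(-2, -1).transpose(0, 1).to(y)  # (1, 4, 2, 2)
    padded = F.pad(y, (1, 0, 1, 0), mode="circular")
    return F.conv2d(padded, adjoint) / 4


x = torch.randn(2, 1, 32, 32, requires_grad=True)
coeffs = haar_swt2_level(x)
recon = haar_iswt2_level(coeffs)
print(coeffs.shape)  # torch.Size([2, 4, 32, 32])

# Gradients pass through: the loss below equals 4 * ||x||^2, so its gradient is 8 * x.
coeffs.pow(2).sum().backward()
```

Dividing by 4 in the inverse is exact for these filters. Under periodic boundaries the four subbands together carry four times the input's energy (the analysis operator W satisfies W^T W = 4I), and that is also why the gradient of `coeffs.pow(2).sum()` comes out as `8 * x`.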

How to use

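```python
import matplotlib.pyplot as plt
import numpy as np
import pywt
import torch

# SWTForward and SWTInverse are provided by this repository; import them from
# the module that defines them before running this example.

J = 3               # number of decomposition levels
wave = 'db1'        # wavelet name, as understood by PyWavelets
mode = 'symmetric'  # padding mode used at the image borders

# Two 512x512 grayscale test images shipped with PyWavelets
img_1 = pywt.data.camera()
img_2 = pywt.data.ascent()
img = np.stack([img_1, img_2], 0)

# Batch of shape (N, C, H, W) = (2, 1, 512, 512)
x = torch.tensor(img).float().unsqueeze(1)

sfm = SWTForward(J, wave, mode)
ifm = SWTInverse(wave, mode)

coeffs = sfm(x)      # forward stationary wavelet transform
recon = ifm(coeffs)  # inverse transform; should match x up to floating-point error

# Top row: reconstructions; bottom row: original inputs
plt.subplot(2, 2, 1)
plt.imshow(recon[0, 0].detach().cpu().numpy(), cmap='gray')
plt.subplot(2, 2, 2)
plt.imshow(recon[1, 0].detach().cpu().numpy(), cmap='gray')

plt.subplot(2, 2, 3)
plt.imshow(x[0, 0].numpy(), cmap='gray')
plt.subplot(2, 2, 4)
plt.imshow(x[1, 0].numpy(), cmap='gray')
plt.show()
```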
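The `J` argument sets the number of decomposition levels. The stationary transform is usually built with the "à trous" scheme: level j reuses the same filters upsampled by 2**j (here, `dilation=2**j`), so every level stays at full resolution. The sketch below applies that scheme to a batch shaped like the example's, with smaller images. It illustrates the idea under stated assumptions and is not the repository's code: it uses Haar filters with circular padding, the names `swt2_haar` and `iswt2_haar` are hypothetical, and its coefficient layout (a list of `(N, 4, H, W)` tensors, finest level first) may not match what `SWTForward` returns.

```python
import torch
import torch.nn.functional as F

# Haar ('db1') analysis filters: low-pass h and high-pass g.
_h = torch.tensor([1.0, 1.0]) / 2 ** 0.5
_g = torch.tensor([1.0, -1.0]) / 2 ** 0.5
# Separable 2D filters, shape (4, 1, 2, 2); channel 0 is the low-low (approximation) filter.
HAAR_2D = torch.stack([torch.outer(a, b) for a in (_h, _g) for b in (_h, _g)]).unsqueeze(1)


def swt2_haar(x: torch.Tensor, J: int) -> list:
    """J-level undecimated Haar transform of x with shape (N, 1, H, W) (hypothetical helper).

    Level j uses dilation 2**j (the "a trous" scheme), so every level keeps
    the full H x W size. Returns a list of (N, 4, H, W) tensors, finest first;
    channel 0 of each entry is that level's approximation.
    """
    coeffs, ll = [], x
    for j in range(J):
        d = 2 ** j
        padded = F.pad(ll, (0, d, 0, d), mode="circular")
        y = F.conv2d(padded, HAAR_2D.to(ll), dilation=d)
        coeffs.append(y)
        ll = y[:, :1]  # the approximation feeds the next, coarser level
    return coeffs


def iswt2_haar(coeffs: list) -> torch.Tensor:
    """Invert swt2_haar from the coarsest level down; each step is W_j^T(.) / 4."""
    ll = coeffs[-1][:, :1]
    for j in reversed(range(len(coeffs))):
        d = 2 ** j
        # Combine the approximation reconstructed so far with this level's details.
        y = torch.cat([ll, coeffs[j][:, 1:]], dim=1)
        adjoint = HAAR_2D.flip(-2, -1).transpose(0, 1).to(y)  # (1, 4, 2, 2)
        padded = F.pad(y, (d, 0, d, 0), mode="circular")
        ll = F.conv2d(padded, adjoint, dilation=d) / 4
    return ll


# Stand-in for the example's (2, 1, 512, 512) batch, at a smaller size.
x = torch.randn(2, 1, 64, 64, requires_grad=True)
coeffs = swt2_haar(x, J=3)
recon = iswt2_haar(coeffs)
print([tuple(c.shape) for c in coeffs])
# [(2, 4, 64, 64), (2, 4, 64, 64), (2, 4, 64, 64)]

recon.sum().backward()  # gradients flow back to x through all three levels
```

The round trip is the identity, so `recon.sum().backward()` leaves `x.grad` filled with ones. This confirms that gradients reach the input through every level.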

Results
