Skip to content

Commit

Permalink
Some improvements to GeometryFeatures and TemporalFeatures. A new tes…
Browse files Browse the repository at this point in the history
…t for confirming that a reverse FFT applied to the FFT output reproduces the original signal.
  • Loading branch information
ashesfall committed Oct 5, 2024
1 parent 02b6403 commit 4374514
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,20 @@ default <T extends PackedCollection<?>> CollectionProducer<T> tanh(Supplier<Eval
null, (Producer) input);
}


default CollectionProducer<PackedCollection<?>> sinw(Producer<PackedCollection<?>> input,
Producer<PackedCollection<?>> wavelength,
Producer<PackedCollection<?>> amp) {
return sin(c(TWO_PI).multiply(input).divide(wavelength)).multiply(amp);
}

default CollectionProducer<PackedCollection<?>> sinw(Producer<PackedCollection<?>> input,
Producer<PackedCollection<?>> wavelength,
Producer<PackedCollection<?>> phase,
Producer<PackedCollection<?>> amp) {
return sin(c(TWO_PI).multiply(divide(input, wavelength).subtract(phase))).multiply(amp);
}

@Deprecated
default ExpressionComputation relativeSin(Supplier<Evaluable<? extends PackedCollection<?>>> input) {
Function<List<ArrayVariable<Double>>, Expression<Double>> exp = args ->
Expand All @@ -79,16 +93,16 @@ default ExpressionComputation relativeCos(Supplier<Evaluable<? extends PackedCol
return new ExpressionComputation(List.of(exp), input);
}

default CollectionProducer<PackedCollection<?>> sinw(Producer<PackedCollection<?>> input,
Producer<PackedCollection<?>> wavelength,
Producer<PackedCollection<?>> amp) {
default CollectionProducer<PackedCollection<?>> relativeSinw(Producer<PackedCollection<?>> input,
Producer<PackedCollection<?>> wavelength,
Producer<PackedCollection<?>> amp) {
return relativeSin(c(TWO_PI).multiply(input).divide(wavelength)).multiply(amp);
}

default CollectionProducer<PackedCollection<?>> sinw(Producer<PackedCollection<?>> input,
Producer<PackedCollection<?>> wavelength,
Producer<PackedCollection<?>> phase,
Producer<PackedCollection<?>> amp) {
default CollectionProducer<PackedCollection<?>> relativeSinw(Producer<PackedCollection<?>> input,
Producer<PackedCollection<?>> wavelength,
Producer<PackedCollection<?>> phase,
Producer<PackedCollection<?>> amp) {
return relativeSin(c(TWO_PI).multiply(divide(input, wavelength).subtract(phase))).multiply(amp);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ default Interpolate interpolate(
}

default FourierTransform fft(int bins, Producer<PackedCollection<?>> input, ComputeRequirement... requirements) {
FourierTransform fft = new FourierTransform(bins, input);
return fft(bins, false, input, requirements);
}

default FourierTransform fft(int bins, boolean inverse,
Producer<PackedCollection<?>> input, ComputeRequirement... requirements) {
FourierTransform fft = new FourierTransform(bins, inverse, input);
if (requirements.length > 0) fft.setComputeRequirements(List.of(requirements));
return fft;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

import io.almostrealism.code.ComputeRequirement;
import org.almostrealism.collect.PackedCollection;
import org.almostrealism.time.Frequency;
import org.almostrealism.time.computations.FourierTransform;
import org.almostrealism.util.TestFeatures;
import org.junit.Assert;
import org.junit.Test;

import java.util.List;
Expand Down Expand Up @@ -48,4 +50,46 @@ public void compile(ComputeRequirement requirement) {
ft.setComputeRequirements(List.of(requirement));
ft.get().evaluate();
}

@Test
public void forwardAndInverse() {
forwardAndInverse(ComputeRequirement.CPU);
}

public void forwardAndInverse(ComputeRequirement requirement) {
int bins = 1024;
Frequency f1 = new Frequency(440.00);
Frequency f2 = new Frequency(587.33);

PackedCollection<?> input = new PackedCollection<>(2, bins);

a(cp(input.range(shape(bins)).each()),
add(
sinw(integers(0, bins), c(f1.getWaveLength()), c(0.9)),
sinw(integers(0, bins), c(f2.getWaveLength()), c(0.6))))
.get().run();

FourierTransform ft = fft(bins, cp(input), requirement);
PackedCollection<?> out = ft.get().evaluate();
log(out.getShape());

FourierTransform ift = fft(bins, true, cp(out), requirement);
PackedCollection<?> reversed = ift.get().evaluate();
log(reversed.getShape());

int total = 0;

for (int i = 0; i < bins; i++) {
double expected = bins * input.valueAt(0, i);

if (expected != 0) {
double actual = reversed.toDouble(i);
log(expected + " vs " + actual);
assertSimilar(expected, actual, 0.0001);
total++;
}
}

Assert.assertTrue(total > (bins * 0.9));
}
}

0 comments on commit 4374514

Please sign in to comment.