Skip to content

Commit

Permalink
Refactored env/pitch to avoid unsafe array indexing.
Browse files Browse the repository at this point in the history
Moved all the things we needed per channel into an "info" struct.
Changed a bunch of arrays into a single std::vector<info_t>.
Use std::vector<info_t>::at(index) to validate every memory access
into a reference 'q'. Use 'q' for all read/write to a given channel.
  • Loading branch information
cosinekitty committed Feb 2, 2025
1 parent 86f640e commit e483352
Showing 1 changed file with 69 additions and 50 deletions.
119 changes: 69 additions & 50 deletions src/env_pitch_detect.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,69 @@

namespace Sapphire
{
template <typename value_t, int filterLayers>
struct EnvPitchChannelInfo
{
value_t prevSignal;
int ascendSamples; // wavelength sample counters between consecutive ascending zero-crossings
int descendSamples; // wavelength sample counters between consecutive descending zero-crossings
int rawWaveLengthAscend;
int rawWaveLengthDescend;
value_t filteredWaveLength;

using filter_t = StagedFilter<value_t, filterLayers>;
filter_t loCutFilter;
filter_t hiCutFilter;
filter_t jitterFilter;
filter_t amplFilter;

EnvPitchChannelInfo()
{
initialize();
}

void initialize()
{
prevSignal = 0;
ascendSamples = 0;
descendSamples = 0;
rawWaveLengthAscend = 0;
rawWaveLengthDescend = 0;
filteredWaveLength = 0;

// Reset all filters in case they went non-finite.
loCutFilter.Reset();
hiCutFilter.Reset();
jitterFilter.Reset();
amplFilter.Reset();
}
};


template <typename value_t, int maxChannels, int filterLayers = 3>
class EnvPitchDetector
{
private:
static_assert(maxChannels > 0);

float centerFrequencyHz = 261.6255653005986; // note C4 = 440 / (2**(3/4))

int currentSampleRate = 0;

value_t prevSignal[maxChannels];
int ascendSamples[maxChannels]; // wavelength sample counters between consecutive ascending zero-crossings
int descendSamples[maxChannels]; // wavelength sample counters between consecutive descending zero-crossings
int rawWaveLengthAscend[maxChannels];
int rawWaveLengthDescend[maxChannels];
value_t filteredWaveLength[maxChannels];

using filter_t = StagedFilter<value_t, filterLayers>;
filter_t loCutFilter[maxChannels];
filter_t hiCutFilter[maxChannels];
filter_t jitterFilter[maxChannels];
filter_t amplFilter[maxChannels];
float loCutFrequency = 20;
float hiCutFrequency = 3000;
float jitterCornerFrequency = 10;
float amplCornerFrequency = 10;
int recoveryCountdown; // how many samples remain before trying to filter again (CPU usage limiter)
int recoveryCountdown = 0; // how many samples remain before trying to filter again (CPU usage limiter)

using info_t = EnvPitchChannelInfo<value_t, filterLayers>;
std::vector<info_t> info;

value_t updateAmplitude(int channel, value_t signal)
{
// Square the signal and filter the result.
// This gives us a time-smeared measure of power.
amplFilter[channel].SetCutoffFrequency(amplCornerFrequency);
return amplFilter[channel].UpdateLoPass(signal*signal, currentSampleRate);
info_t& q = info.at(channel);
q.amplFilter.SetCutoffFrequency(amplCornerFrequency);
return q.amplFilter.UpdateLoPass(signal*signal, currentSampleRate);
}

void updateWaveLength(int channel, int wavelengthSamples)
Expand All @@ -59,34 +88,22 @@ namespace Sapphire
if (rawFrequencyHz < loCutFrequency || rawFrequencyHz > hiCutFrequency)
return;

jitterFilter[channel].SetCutoffFrequency(jitterCornerFrequency);
filteredWaveLength[channel] = jitterFilter[channel].UpdateLoPass(wavelengthSamples, currentSampleRate);
info_t& q = info.at(channel);
q.jitterFilter.SetCutoffFrequency(jitterCornerFrequency);
q.filteredWaveLength = info[channel].jitterFilter.UpdateLoPass(wavelengthSamples, currentSampleRate);
}

public:
EnvPitchDetector()
{
initialize();
info.resize(maxChannels);
}

void initialize()
{
recoveryCountdown = 0;
for (int c = 0; c < maxChannels; ++c)
{
prevSignal[c] = 0;
ascendSamples[c] = 0;
descendSamples[c] = 0;
rawWaveLengthAscend[c] = 0;
rawWaveLengthDescend[c] = 0;
filteredWaveLength[c] = 0;

// Reset all filters in case they went non-finite.
loCutFilter[c].Reset();
hiCutFilter[c].Reset();
jitterFilter[c].Reset();
amplFilter[c].Reset();
}
info[c].initialize();
}

void process(
Expand Down Expand Up @@ -125,20 +142,22 @@ namespace Sapphire

for (int c = 0; c < numChannels; ++c)
{
++ascendSamples[c];
++descendSamples[c];
info_t& q = info.at(c);

++q.ascendSamples;
++q.descendSamples;

// Feed through a bandpass filter that rejects DC and other frequencies below 20 Hz,
// and also rejects very high frequencies.

// Reject frequencies lower than we want to keep.
loCutFilter[c].SetCutoffFrequency(loCutFrequency);
value_t locut = loCutFilter[c].UpdateHiPass(inFrame[c], sampleRateHz);
q.loCutFilter.SetCutoffFrequency(loCutFrequency);
value_t locut = q.loCutFilter.UpdateHiPass(inFrame[c], sampleRateHz);

// Reject frequencies higher than we want to keep.
// The band-pass result is our signal to feed through envelope and pitch detection.
hiCutFilter[c].SetCutoffFrequency(hiCutFrequency);
value_t signal = hiCutFilter[c].UpdateLoPass(locut, sampleRateHz);
q.hiCutFilter.SetCutoffFrequency(hiCutFrequency);
value_t signal = q.hiCutFilter.UpdateLoPass(locut, sampleRateHz);

// Make sure we have a normal numeric value for our signal.
if (!std::isfinite(signal))
Expand All @@ -161,31 +180,31 @@ namespace Sapphire
// Find (both ascending and descending), independently for each channel.
// Measure the interval between them, expressed in samples, called "wavelength".

if (signal * prevSignal[c] < 0)
if (signal * q.prevSignal < 0)
{
if (signal > 0)
{
rawWaveLengthAscend[c] = ascendSamples[c];
ascendSamples[c] = 0;
q.rawWaveLengthAscend = q.ascendSamples;
q.ascendSamples = 0;
}
else
{
rawWaveLengthDescend[c] = descendSamples[c];
descendSamples[c] = 0;
q.rawWaveLengthDescend = q.descendSamples;
q.descendSamples = 0;
}
}

prevSignal[c] = signal;
q.prevSignal = signal;
}

updateWaveLength(c, rawWaveLengthAscend[c]);
updateWaveLength(c, rawWaveLengthDescend[c]);
updateWaveLength(c, q.rawWaveLengthAscend);
updateWaveLength(c, q.rawWaveLengthDescend);

// Convert wavelength [samples] to frequency [Hz] to pitch [V/OCT].
// samplerate/wavelength: [samples/sec]/[samples] = [1/sec] = [Hz]
if (filteredWaveLength[c] > sampleRateHz/4000)
if (q.filteredWaveLength > sampleRateHz/4000)
{
value_t frequencyHz = sampleRateHz / filteredWaveLength[c];
value_t frequencyHz = sampleRateHz / q.filteredWaveLength;
outPitchVoct[c] = std::log2(frequencyHz / centerFrequencyHz);
}
}
Expand Down

0 comments on commit e483352

Please sign in to comment.