From e483352237865aa10e800582afad4eacd1ec6714 Mon Sep 17 00:00:00 2001 From: Don Cross Date: Sun, 2 Feb 2025 12:37:24 -0500 Subject: [PATCH] Refactored env/pitch to avoid unsafe array indexing. Moved all the things we needed per channel into an "info" struct. Changed a bunch of arrays into a single std::vector. Use std::vector::at(index) to validate every memory access into a reference 'q'. Use 'q' for all read/write to a given channel. --- src/env_pitch_detect.hpp | 119 +++++++++++++++++++++++---------------- 1 file changed, 69 insertions(+), 50 deletions(-) diff --git a/src/env_pitch_detect.hpp b/src/env_pitch_detect.hpp index a966888..36a0ab4 100644 --- a/src/env_pitch_detect.hpp +++ b/src/env_pitch_detect.hpp @@ -4,6 +4,45 @@ namespace Sapphire { + template + struct EnvPitchChannelInfo + { + value_t prevSignal; + int ascendSamples; // wavelength sample counters between consecutive ascending zero-crossings + int descendSamples; // wavelength sample counters between consecutive descending zero-crossings + int rawWaveLengthAscend; + int rawWaveLengthDescend; + value_t filteredWaveLength; + + using filter_t = StagedFilter; + filter_t loCutFilter; + filter_t hiCutFilter; + filter_t jitterFilter; + filter_t amplFilter; + + EnvPitchChannelInfo() + { + initialize(); + } + + void initialize() + { + prevSignal = 0; + ascendSamples = 0; + descendSamples = 0; + rawWaveLengthAscend = 0; + rawWaveLengthDescend = 0; + filteredWaveLength = 0; + + // Reset all filters in case they went non-finite. + loCutFilter.Reset(); + hiCutFilter.Reset(); + jitterFilter.Reset(); + amplFilter.Reset(); + } + }; + + template class EnvPitchDetector { @@ -11,33 +50,23 @@ namespace Sapphire static_assert(maxChannels > 0); float centerFrequencyHz = 261.6255653005986; // note C4 = 440 / (2**(3/4)) - int currentSampleRate = 0; - - value_t prevSignal[maxChannels]; - int ascendSamples[maxChannels]; // wavelength sample counters between consecutive ascending zero-crossings - int descendSamples[maxChannels]; // wavelength sample counters between consecutive descending zero-crossings - int rawWaveLengthAscend[maxChannels]; - int rawWaveLengthDescend[maxChannels]; - value_t filteredWaveLength[maxChannels]; - - using filter_t = StagedFilter; - filter_t loCutFilter[maxChannels]; - filter_t hiCutFilter[maxChannels]; - filter_t jitterFilter[maxChannels]; - filter_t amplFilter[maxChannels]; float loCutFrequency = 20; float hiCutFrequency = 3000; float jitterCornerFrequency = 10; float amplCornerFrequency = 10; - int recoveryCountdown; // how many samples remain before trying to filter again (CPU usage limiter) + int recoveryCountdown = 0; // how many samples remain before trying to filter again (CPU usage limiter) + + using info_t = EnvPitchChannelInfo; + std::vector info; value_t updateAmplitude(int channel, value_t signal) { // Square the signal and filter the result. // This gives us a time-smeared measure of power. - amplFilter[channel].SetCutoffFrequency(amplCornerFrequency); - return amplFilter[channel].UpdateLoPass(signal*signal, currentSampleRate); + info_t& q = info.at(channel); + q.amplFilter.SetCutoffFrequency(amplCornerFrequency); + return q.amplFilter.UpdateLoPass(signal*signal, currentSampleRate); } void updateWaveLength(int channel, int wavelengthSamples) @@ -59,34 +88,22 @@ namespace Sapphire if (rawFrequencyHz < loCutFrequency || rawFrequencyHz > hiCutFrequency) return; - jitterFilter[channel].SetCutoffFrequency(jitterCornerFrequency); - filteredWaveLength[channel] = jitterFilter[channel].UpdateLoPass(wavelengthSamples, currentSampleRate); + info_t& q = info.at(channel); + q.jitterFilter.SetCutoffFrequency(jitterCornerFrequency); + q.filteredWaveLength = info[channel].jitterFilter.UpdateLoPass(wavelengthSamples, currentSampleRate); } public: EnvPitchDetector() { - initialize(); + info.resize(maxChannels); } void initialize() { recoveryCountdown = 0; for (int c = 0; c < maxChannels; ++c) - { - prevSignal[c] = 0; - ascendSamples[c] = 0; - descendSamples[c] = 0; - rawWaveLengthAscend[c] = 0; - rawWaveLengthDescend[c] = 0; - filteredWaveLength[c] = 0; - - // Reset all filters in case they went non-finite. - loCutFilter[c].Reset(); - hiCutFilter[c].Reset(); - jitterFilter[c].Reset(); - amplFilter[c].Reset(); - } + info[c].initialize(); } void process( @@ -125,20 +142,22 @@ namespace Sapphire for (int c = 0; c < numChannels; ++c) { - ++ascendSamples[c]; - ++descendSamples[c]; + info_t& q = info.at(c); + + ++q.ascendSamples; + ++q.descendSamples; // Feed through a bandpass filter that rejects DC and other frequencies below 20 Hz, // and also rejects very high frequencies. // Reject frequencies lower than we want to keep. - loCutFilter[c].SetCutoffFrequency(loCutFrequency); - value_t locut = loCutFilter[c].UpdateHiPass(inFrame[c], sampleRateHz); + q.loCutFilter.SetCutoffFrequency(loCutFrequency); + value_t locut = q.loCutFilter.UpdateHiPass(inFrame[c], sampleRateHz); // Reject frequencies higher than we want to keep. // The band-pass result is our signal to feed through envelope and pitch detection. - hiCutFilter[c].SetCutoffFrequency(hiCutFrequency); - value_t signal = hiCutFilter[c].UpdateLoPass(locut, sampleRateHz); + q.hiCutFilter.SetCutoffFrequency(hiCutFrequency); + value_t signal = q.hiCutFilter.UpdateLoPass(locut, sampleRateHz); // Make sure we have a normal numeric value for our signal. if (!std::isfinite(signal)) @@ -161,31 +180,31 @@ namespace Sapphire // Find (both ascending and descending), independently for each channel. // Measure the interval between them, expressed in samples, called "wavelength". - if (signal * prevSignal[c] < 0) + if (signal * q.prevSignal < 0) { if (signal > 0) { - rawWaveLengthAscend[c] = ascendSamples[c]; - ascendSamples[c] = 0; + q.rawWaveLengthAscend = q.ascendSamples; + q.ascendSamples = 0; } else { - rawWaveLengthDescend[c] = descendSamples[c]; - descendSamples[c] = 0; + q.rawWaveLengthDescend = q.descendSamples; + q.descendSamples = 0; } } - prevSignal[c] = signal; + q.prevSignal = signal; } - updateWaveLength(c, rawWaveLengthAscend[c]); - updateWaveLength(c, rawWaveLengthDescend[c]); + updateWaveLength(c, q.rawWaveLengthAscend); + updateWaveLength(c, q.rawWaveLengthDescend); // Convert wavelength [samples] to frequency [Hz] to pitch [V/OCT]. // samplerate/wavelength: [samples/sec]/[samples] = [1/sec] = [Hz] - if (filteredWaveLength[c] > sampleRateHz/4000) + if (q.filteredWaveLength > sampleRateHz/4000) { - value_t frequencyHz = sampleRateHz / filteredWaveLength[c]; + value_t frequencyHz = sampleRateHz / q.filteredWaveLength; outPitchVoct[c] = std::log2(frequencyHz / centerFrequencyHz); } }