Gammatone Filterbank waveform outputs #128

Open
astrocyted opened this issue Mar 31, 2023 · 4 comments

@astrocyted

Hi,
I'm interested in having an nn.Module gammatone filterbank that produces the filtered outputs directly, shaped [N_filters x signal_length]. Would it be possible to achieve this within your framework without having to loop over the number of filters?

@KinWaiCheuk
Owner

Isn't Gammatonegram already like this? The output is already [Batch, N_filters x Signal length]. Or am I misunderstanding your question?

@astrocyted
Author

Isn't Gammatonegram already like this? The output is already [Batch, N_filters x Signal length]. Or am I misunderstanding your question?

Just saw your reply now. No, that's clearly not what Gammatonegram returns; check the docs of your code:
Returns
-------
spectrogram : torch.tensor
It returns a tensor of spectrograms. shape = (num_samples, freq_bins,time_steps).

time_steps is not the signal length but rather signal_length / frame_hop. I want the per-channel IIR-filtered waveform, not the binned FFT output.
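
To make the shape difference concrete, here is a rough sketch (I'm assuming Gammatonegram is exposed under nnAudio.features and takes sr/n_fft/hop_length/n_bins arguments; the exact names may differ between nnAudio versions):

import torch
from nnAudio import features  # assumption: Gammatonegram lives here in recent releases

x = torch.randn(1, 16000)  # 1 s of audio at 16 kHz

# What Gammatonegram returns: framed, STFT-based gammatone energies
gtgram = features.Gammatonegram(sr=16000, n_fft=1024, hop_length=160, n_bins=64)
print(gtgram(x).shape)  # roughly (1, 64, 16000 // 160 + 1): time is downsampled by the hop

# What I am asking for instead: one full-length IIR-filtered waveform per channel,
# i.e. a tensor of shape (1, 64, 16000), with no framing or hop involved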

@KinWaiCheuk
Owner

I understand your question now. I am not familiar with the gammatone filterbank or the gammatonegram; this feature was implemented by @WangHelin1997. Maybe he can comment more on it?

Alternatively, can you recommend any Python library that produces the filtered waveforms? I will check whether I can implement it under the current nnAudio framework. It would be a great help to have something to refer to, just to verify that I implement it correctly.

@astrocyted
Author

https://github.com/detly/gammatone/blob/master/gammatone/filters.py
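
For reference, it is used roughly like this (call signatures paraphrased from gammatone/filters.py, so treat them as approximate):

import numpy as np
from gammatone.filters import centre_freqs, make_erb_filters, erb_filterbank

fs = 16000
wave = np.random.randn(fs)          # 1 s of noise

cfs = centre_freqs(fs, 64, 50)      # 64 centre frequencies with a 50 Hz lower cutoff
fcoefs = make_erb_filters(fs, cfs)  # per-filter IIR coefficients
out = erb_filterbank(wave, fcoefs)  # shape (64, len(wave)): one filtered waveform per channel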

This is one example of an implementation; the output of its erb_filterbank() function (as sketched above) is what I'm asking for. It's quite slow, though. I tried to do it in torch myself as well, but it isn't really any faster:

import math

import torch
import torchaudio.functional as F  # lfilter lives here


class GammatoneFilterbank(torch.nn.Module):
    def __init__(self,
                 num_filters=64,
                 sample_rate=16000,
                 fmin=50,
                 fmax=None,
                 gtgram=False,
                 frame_length=400,
                 hop_length=160,
                 ):

        super(GammatoneFilterbank, self).__init__()
        self.num_filters = num_filters
        self.sample_rate = sample_rate

        self.gtgram = gtgram
        self.frame_length = frame_length
        self.hop_length = hop_length

        self.fmin = fmin
        self.fmax = fmax if fmax else self.sample_rate / 2

        # Precompute the ERB-spaced centre frequencies and the per-filter IIR
        # coefficients; registering them as a buffer lets them follow .to(device).
        self.centre_freqs = self.centre_frequencies()
        self.register_buffer("filter_coefs", self.make_erb_filters())

    @staticmethod
    def erb_point(low_freq, high_freq, fraction):
        ear_q = 9.26449  # Glasberg and Moore Parameters
        min_bw = 24.7
        order = 1
        
        low_freq = torch.tensor(low_freq)
        high_freq = torch.tensor(high_freq)
        
        
        erb_point = (
            -ear_q * min_bw
            + torch.exp(
                fraction * (
                    -torch.log(high_freq + ear_q * min_bw)
                    + torch.log(low_freq + ear_q * min_bw)
                )
            ) *
            (high_freq + ear_q * min_bw)
        )
        
        return erb_point

    @staticmethod
    def erb_space(
        low_freq=50,
        high_freq=8000,
        num_bands=64):
        """
        This function computes an array of ``num_bands`` frequencies uniformly
        spaced between ``high_freq`` and ``low_freq`` on an ERB scale.
        
        For a definition of ERB, see Moore, B. C. J., and Glasberg, B. R. (1983).
        "Suggested formulae for calculating auditory-filter bandwidths and
        excitation patterns," J. Acoust. Soc. Am. 74, 750-753.
        """
        return GammatoneFilterbank.erb_point(
            low_freq,
            high_freq,
            torch.arange(1, num_bands + 1) / num_bands
            )

    
    def centre_frequencies(self):
        """
        Calculates an array of centre frequencies (for :func:`make_erb_filters`)
        from the module's sampling rate, lower/upper cutoff frequencies and the
        desired number of filters.

        :return: same as :func:`erb_space`
        """
        return GammatoneFilterbank.erb_space(low_freq=self.fmin, high_freq=self.fmax, num_bands=self.num_filters)



    def make_erb_filters(self, width=1.0):
        T = 1 / self.sample_rate
        ear_q = 9.26449 # Glasberg and Moore Parameters
        min_bw = 24.7
        order = 1

        if not torch.is_tensor(self.centre_freqs):
            self.centre_freqs = torch.Tensor(self.centre_freqs)
        
        erb = width*((self.centre_freqs / ear_q) ** order + min_bw ** order) ** (1 / order)
        B = 1.019 * 2 * math.pi * erb

        arg = 2 * math.pi * self.centre_freqs * T
        vec = torch.exp(2j * arg)

        A0 = T
        A2 = 0
        B0 = 1
        B1 = -2 * torch.cos(arg) / torch.exp(B * T)
        B2 = torch.exp(-2 * B * T)

        rt_pos = torch.sqrt(torch.tensor(3 + 2 ** 1.5))
        rt_neg = torch.sqrt(torch.tensor(3 - 2 ** 1.5))

        common = -T * torch.exp(-(B * T))

        k11 = torch.cos(arg) + rt_pos * torch.sin(arg)
        k12 = torch.cos(arg) - rt_pos * torch.sin(arg)
        k13 = torch.cos(arg) + rt_neg * torch.sin(arg)
        k14 = torch.cos(arg) - rt_neg * torch.sin(arg)

        A11 = common * k11
        A12 = common * k12
        A13 = common * k13
        A14 = common * k14

        gain_arg = torch.exp(1j * arg - B * T)

        gain = torch.abs(
            (vec - gain_arg * k11)
            * (vec - gain_arg * k12)
            * (vec - gain_arg * k13)
            * (vec - gain_arg * k14)
            * (T * torch.exp(B * T)
                / (-1 / torch.exp(B * T) + 1 + vec * (1 - torch.exp(B * T)))
            )**4
        )

        allfilts = torch.ones_like(self.centre_freqs)

        fcoefs = torch.stack([
            A0 * allfilts, A11, A12, A13, A14, A2*allfilts,
            B0 * allfilts, B1, B2,
            gain
        ], dim=1)

        return fcoefs

    def erb_filterbank(self, waveform):
        # waveform: (time,) or (batch, time); a 1-D input is promoted to batch size 1
        if waveform.ndim == 1:
            waveform = waveform[None, :]
        
        gain = self.filter_coefs[:, 9]
        # A0, A11, A2
        As1 = self.filter_coefs[:, (0, 1, 5)]
        # A0, A12, A2
        As2 = self.filter_coefs[:, (0, 2, 5)]
        # A0, A13, A2
        As3 = self.filter_coefs[:, (0, 3, 5)]
        # A0, A14, A2
        As4 = self.filter_coefs[:, (0, 4, 5)]
        # B0, B1, B2
        Bs = self.filter_coefs[:, 6:9]
        
        # Broadcast the single waveform across all filters so that one batched
        # lfilter call handles every channel at once (assumes batch size 1).
        stacked_waveforms = waveform.expand(self.filter_coefs.shape[0], *waveform.shape[1:])

        # Cascade of four second-order sections per channel; note that
        # torchaudio's lfilter takes (waveform, a_coeffs, b_coeffs).
        y1 = F.lfilter(stacked_waveforms, Bs, As1, clamp=False)
        y2 = F.lfilter(y1, Bs, As2, clamp=False)
        y3 = F.lfilter(y2, Bs, As3, clamp=False)
        y4 = F.lfilter(y3, Bs, As4, clamp=False)

        return y4 / gain.unsqueeze(-1)
    
    def forward(self, x):
        if self.gtgram:
            # Gammatonegram mode: filter, then take per-frame RMS energy.
            x = self.erb_filterbank(x)
            x = torch.nn.functional.pad(x, (self.frame_length // 2, self.frame_length - self.frame_length // 2))
            x = torch.sum(x.unfold(-1, self.frame_length, self.hop_length) ** 2, dim=-1)
            return torch.sqrt(x)
        else:
            # Filterbank mode: return the per-channel IIR-filtered waveforms.
            return self.erb_filterbank(x)
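
A quick shape check for the class above (assuming a 1 s mono signal at 16 kHz and batch size 1, which is all the expand() trick supports):

fb = GammatoneFilterbank(num_filters=64, sample_rate=16000)
x = torch.randn(1, 16000)  # 1 s of audio, batch size 1

waveforms = fb(x)          # (64, 16000): one full-length filtered signal per channel
print(waveforms.shape)

gt = GammatoneFilterbank(num_filters=64, sample_rate=16000, gtgram=True)
print(gt(x).shape)         # (64, 101): framed energies, i.e. a gammatonegram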

I guess the fastest implementations are the ones written directly in C.
