
bootphon/learnable-strf


Learning spectro-temporal representations of complex sounds with parameterized neural networks


learnable-strf is an open-source package to replicate the experiments in [1].

Dependencies

The main dependencies of learnable-strf are:

Implementation of the Learnable STRF

The Learnable STRF is straightforward to implement in PyTorch; the implementation below is inspired by the one from this package.

We used the nnAudio package to compute the log Mel filterbanks.
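For reference, the kernel that STRFConv2d builds for output channel $i$ and input channel $j$ can be transcribed from its forward method as follows. This is a reading of the code, not necessarily the paper's notation: $(f, t)$ runs over the kernel grid along frequency and time, $\sigma^x_{ij}$, $\sigma^y_{ij}$, $\omega_{ij}$ and $\gamma_{ij}$ stand for the sigma_x, sigma_y, freq and gamma parameters, and $\epsilon = 10^{-3}$. With use_real=False the cosine is replaced by a sine; theta and psi do not appear in the expression.

```math
G_{ij}(f, t) = \frac{1}{2\pi\,\sigma^{x}_{ij}\,\sigma^{y}_{ij}}
\exp\!\left(-\frac{1}{2}\left[\frac{f^{2}}{(\sigma^{x}_{ij}+\epsilon)^{2}}
+ \frac{t^{2}}{(\sigma^{y}_{ij}+\epsilon)^{2}}\right]\right)
\cos\!\big(\omega_{ij}\,(t\cos\gamma_{ij} + f\sin\gamma_{ij})\big)
```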

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair
from torch.nn.utils.rnn import PackedSequence

# The log Mel filterbanks fed to this layer were computed with nnAudio,
# which is not needed to define the layer itself.


class STRFConv2d(_ConvNd):
  """Conv2d whose kernels are parameterized spectro-temporal receptive
  fields (Gabor-like filters) rather than free weights."""

  def __init__(self,
               in_channels,
               out_channels,
               kernel_size,
               stride=1,
               padding=0,
               dilation=1,
               groups=1,
               bias=False,
               padding_mode='zeros',
               device=None,  # accepted for API compatibility, unused
               n_features=64):

      kernel_size = _pair(kernel_size)  # int or (freq, time) tuple
      stride = _pair(stride)
      padding = _pair(padding)
      dilation = _pair(dilation)

      super(STRFConv2d,
            self).__init__(in_channels, out_channels,
                           kernel_size, stride, padding, dilation, False,
                           _pair(0), groups, bias, padding_mode)
      self.n_features = n_features  # number of frequency bins (Mel bands)

      shape = (out_channels, in_channels)

      def as_param(array):
          return nn.Parameter(torch.tensor(array, dtype=torch.float32))

      # Angles drawn uniformly on [-pi, pi] (von Mises with kappa = 0).
      self.theta = as_param(np.random.vonmises(0, 0, shape))
      self.gamma = as_param(np.random.vonmises(0, 0, shape))
      self.psi = as_param(np.random.vonmises(0, 0, shape))
      # Carrier frequency: (pi / 2) * 1.41**(-u), with u ~ U(0, 5).
      self.freq = as_param(
          (np.pi / 2) * 1.41**(-np.random.uniform(0, 5, size=shape)))
      # Envelope widths along frequency (sigma_x) and time (sigma_y):
      # 2 * 1.41**u, with u ~ U(0, 6).
      self.sigma_x = as_param(2 * 1.41**(np.random.uniform(0, 6, shape)))
      self.sigma_y = as_param(2 * 1.41**(np.random.uniform(0, 6, shape)))
      # Half-extent of the kernel grid along frequency and time.
      self.f0 = math.ceil(self.kernel_size[0] / 2)
      self.t0 = math.ceil(self.kernel_size[1] / 2)

  def forward(self, sequences, use_real=True):
      # Only dense tensors are supported: the reshape below cannot be
      # applied to a PackedSequence.
      if isinstance(sequences, PackedSequence):
          raise TypeError('STRFConv2d expects a padded tensor, '
                          'not a PackedSequence')
      device = sequences.device
      # Reshape to (batch, 1, n_features, frames).
      sequences = sequences.reshape(
          sequences.size(0), 1, self.n_features, -1)
      # Frequency (f) and time (t) coordinates of the kernel grid.
      f, t = torch.meshgrid(
          torch.linspace(-self.f0 + 1, self.f0, self.kernel_size[0],
                         device=device),
          torch.linspace(-self.t0 + 1, self.t0, self.kernel_size[1],
                         device=device),
          indexing='ij')
      # Trailing axes let the (out_channels, in_channels) parameters
      # broadcast over the grid, replacing per-channel Python loops.
      sigma_x = self.sigma_x[..., None, None]
      sigma_y = self.sigma_y[..., None, None]
      freq = self.freq[..., None, None]
      gamma = self.gamma[..., None, None]
      # Note: theta and psi are registered as parameters but do not enter
      # the kernel below.
      rot_gamma = t * torch.cos(gamma) + f * torch.sin(gamma)
      # Gaussian envelope over frequency (sigma_x) and time (sigma_y).
      g = torch.exp(-0.5 * ((f**2) / (sigma_x + 1e-3)**2 +
                            (t**2) / (sigma_y + 1e-3)**2))
      # Sinusoidal carrier: real (cos) or imaginary (sin) part.
      if use_real:
          g = g * torch.cos(freq * rot_gamma)
      else:
          g = g * torch.sin(freq * rot_gamma)
      weight = g / (2 * math.pi * sigma_x * sigma_y)
      # Mirror the kernels into self.weight for inspection only.
      self.weight.data.copy_(weight.detach())
      return F.conv2d(sequences, weight, self.bias, self.stride,
                      self.padding, self.dilation, self.groups)
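The kernel computation can also be checked in isolation. The sketch below uses a hypothetical helper, strf_kernel_bank (not part of the package), that transcribes the kernel formula from STRFConv2d.forward into a standalone function operating on all channel pairs at once, then applies the resulting bank to a random tensor standing in for a 64-band log-Mel spectrogram. The parameter draws mimic the initialization in __init__ but use torch instead of NumPy.

```python
import math

import torch
import torch.nn.functional as F


def strf_kernel_bank(sigma_x, sigma_y, freq, gamma, kernel_size,
                     use_real=True, eps=1e-3):
    """Hypothetical helper: builds STRF kernels for all channel pairs at once.

    sigma_x, sigma_y, freq and gamma have shape (out_channels, in_channels);
    the result has shape (out_channels, in_channels, kf, kt).
    """
    kf, kt = kernel_size
    f0, t0 = math.ceil(kf / 2), math.ceil(kt / 2)
    # Frequency (f) and time (t) coordinates, same grid as STRFConv2d.
    f, t = torch.meshgrid(torch.linspace(-f0 + 1, f0, kf),
                          torch.linspace(-t0 + 1, t0, kt),
                          indexing="ij")
    # Trailing axes let the per-channel parameters broadcast over the grid.
    sx, sy, w, g = (p[..., None, None] for p in (sigma_x, sigma_y, freq, gamma))
    envelope = torch.exp(-0.5 * (f**2 / (sx + eps)**2 + t**2 / (sy + eps)**2))
    phase = w * (t * torch.cos(g) + f * torch.sin(g))
    carrier = torch.cos(phase) if use_real else torch.sin(phase)
    return envelope * carrier / (2 * math.pi * sx * sy)


torch.manual_seed(0)
shape = (4, 1)  # (out_channels, in_channels)
sigma_x = 2 * 1.41 ** torch.empty(shape).uniform_(0, 6)
sigma_y = 2 * 1.41 ** torch.empty(shape).uniform_(0, 6)
freq = (math.pi / 2) * 1.41 ** -torch.empty(shape).uniform_(0, 5)
gamma = torch.empty(shape).uniform_(-math.pi, math.pi)

weight = strf_kernel_bank(sigma_x, sigma_y, freq, gamma, (9, 9))
log_mel = torch.randn(2, 1, 64, 100)  # stand-in for (batch, 1, n_mels, frames)
out = F.conv2d(log_mel, weight, padding=4)
print(weight.shape, out.shape)  # torch.Size([4, 1, 9, 9]) torch.Size([2, 4, 64, 100])
```

Because the grid follows the class definition, an even kernel size gives an integer grid (for 8 taps, -3 to 4, so index 3 is f = t = 0), while an odd size gives a non-integer step (for 9 taps, -4 to 5 in steps of 1.125).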

Replication of the engineering experiments

Speech Activity Detection

All Speech Activity Detection experiments were run with the pyannote ecosystem.

Databases

The AMI corpus is freely available at https://groups.inf.ed.ac.uk/ami/corpus/.

The CHiME5 corpus can be obtained freely from url.

The train/dev/test protocols are AMI.SpeakerDiarization.MixHeadset and CHiME5.SpeakerDiarization.U01, both accessed through pyannote.database. The AMI protocol can be obtained with the pip commands for AMI; for CHiME5, the following lines must be added to your .pyannote/database.yml:

Protocols:
  CHiME5:
    SpeakerDiarization:
      U01:
        train:
          annotation: /export/fs01/jsalt19/databases/CHiME5/train/allU01_train.rttm
          annotated: /export/fs01/jsalt19/databases/CHiME5/train/allU01_train.uem
        development:
          annotation: /export/fs01/jsalt19/databases/CHiME5/dev/allU01_dev.rttm
          annotated: /export/fs01/jsalt19/databases/CHiME5/dev/allU01_dev.uem
        test:
          annotation: /export/fs01/jsalt19/databases/CHiME5/test/allU01_test.rttm
          annotated: /export/fs01/jsalt19/databases/CHiME5/test/allU01_test.uem

Speaker Verification

We followed the protocol of J. M. Coria et al. and used the STRFTDNN network in place of SincTDNN.

Urban Sound Classification

We followed the protocol of Arnault et al. and only modified the PANN architecture by inserting STRFConv2d on top of the Mel filterbanks.
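The modification described above amounts to placing a learnable 2-D front end between the Mel filterbanks and the rest of the network. The sketch below illustrates only that wiring, under loudly stated assumptions: FrontEndClassifier is a hypothetical name, a plain nn.Conv2d stands in for STRFConv2d, and the small backbone stands in for PANN, whose actual layers are not reproduced here.

```python
import torch
import torch.nn as nn


class FrontEndClassifier(nn.Module):
    """Hypothetical wrapper: a learnable 2-D front end on log-Mel inputs,
    followed by an arbitrary CNN backbone (stand-ins for STRFConv2d and PANN)."""

    def __init__(self, front_end: nn.Module, backbone: nn.Module):
        super().__init__()
        self.front_end = front_end
        self.backbone = backbone

    def forward(self, log_mel: torch.Tensor) -> torch.Tensor:
        # log_mel: (batch, 1, n_mels, frames)
        x = self.front_end(log_mel)  # (batch, n_filters, n_mels, frames)
        return self.backbone(x)      # (batch, n_classes)


n_filters, n_classes = 16, 10  # illustrative sizes, not taken from the paper
# Stand-in for STRFConv2d: same input layout, free weights instead of STRFs.
front_end = nn.Conv2d(1, n_filters, kernel_size=9, padding=4, bias=False)
# Stand-in backbone; its first layer must accept n_filters input channels.
backbone = nn.Sequential(
    nn.Conv2d(n_filters, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(32, n_classes),
)
model = FrontEndClassifier(front_end, backbone)
logits = model(torch.randn(2, 1, 64, 100))
print(logits.shape)  # torch.Size([2, 10])
```

In this sketch the main structural constraint is that the backbone's first layer must accept as many input channels as the front end produces filters.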

Models

The models used for Speech Activity Detection and Speaker Identification are in models.py. This file replaces models.py in the pyannote.audio package so that the Learnable STRF is used.

Acknowledgements

We are very grateful to the authors of pyannote, nnAudio and the urban sound package, and to Theunissen's and Shamma's groups, for the open-source packages and datasets that made this work possible.

[1] Riad, R., Karadayi, J., Bachoud-Lévi, A.-C., Dupoux, E. Learning spectro-temporal representations of complex sounds with parameterized neural networks. The Journal of the Acoustical Society of America.