-
Notifications
You must be signed in to change notification settings - Fork 512
/
base_audio_classification.py
41 lines (34 loc) · 1.32 KB
/
base_audio_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import argparse
from corenet.modeling.models import MODEL_REGISTRY, BaseAnyNNModel
@MODEL_REGISTRY.register(name="__base__", type="audio_classification")
class BaseAudioClassification(BaseAnyNNModel):
    """Common parent for audio classification models.

    Registered in ``MODEL_REGISTRY`` under name ``"__base__"`` with type
    ``"audio_classification"``. It owns the shared
    ``--model.audio-classification.*`` command-line options.

    Args:
        opts: Parsed command-line arguments, forwarded unchanged to
            ``BaseAnyNNModel``.
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        # No audio-specific state here; initialization is delegated entirely.
        super().__init__(opts, *args, **kwargs)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the shared audio-classification options on ``parser``.

        Options are only added when called on this exact class. A subclass
        that inherits this method without overriding it gets ``parser``
        back untouched, so the options are never registered twice.

        Args:
            parser: Parser to extend.

        Returns:
            The same ``parser`` object, possibly with a new argument group.
        """
        if cls == BaseAudioClassification:
            shared_options = (
                (
                    "--model.audio-classification.name",
                    "Name of the audio classification model. Defaults to None.",
                ),
                (
                    "--model.audio-classification.pretrained",
                    "Path of the pretrained backbone. Defaults to None.",
                ),
            )
            arg_group = parser.add_argument_group(title=cls.__name__)
            # Both options are optional strings that default to None.
            for option_flag, option_help in shared_options:
                arg_group.add_argument(
                    option_flag,
                    type=str,
                    default=None,
                    help=option_help,
                )
        return parser