New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support partial hypotheses for batched RNN-T and TDT decoding #9106
base: main
Are you sure you want to change the base?
Conversation
…abel-Looping algorithm) Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly some FYI comments for now. Looks good to me overall.
@@ -327,10 +327,11 @@ def autocast(): | |||
# configure the decoding config | |||
decoding_cfg = asr_model.cfg.decoding | |||
with open_dict(decoding_cfg): | |||
decoding_cfg.strategy = "greedy" | |||
decoding_cfg.strategy = "greedy_batch" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's very confusing, but there is no "greedy_batch" strategy only for CTC, so if you ran with --decoder_type=ctc
, I think your could would crash.
The original logic is all messed up, but this ended up working for me for both CTC and RNN-T, IIRC:
@@ -325,14 +316,33 @@ def main():
yield
# configure the decoding config
+
+ if args.set_decoder is not None:
+ if hasattr(asr_model, "cur_decoder"):
+ decoder_type = args.set_decoder
+ else:
+ raise ValueError("Decoder cannot get changed for non-Hybrid ASR models.")
+ else:
+ decoder_type = None
+
+
decoding_cfg = asr_model.cfg.decoding
with open_dict(decoding_cfg):
decoding_cfg.strategy = "greedy"
decoding_cfg.preserve_alignments = False
- if hasattr(asr_model, 'joint'): # if an RNNT model
- decoding_cfg.greedy.max_symbols = 10
+ if decoder_type == "rnnt": # if an RNNT model
+ # We need partial hypothesis support here...
+ decoding_cfg.strategy = "greedy_batch"
decoding_cfg.fused_batch_size = -1
- asr_model.change_decoding_strategy(decoding_cfg)
+ decoding_cfg.greedy.max_symbols_per_step = 10
+ decoding_cfg.greedy.loop_labels = True
+ # TODO: Why isn't this working???
+ decoding_cfg.greedy.use_cuda_graph_decoder = True
+ # import ipdb; ipdb.set_trace()
+ elif decoder_type == "ctc":
+ decoding_cfg.greedy.batched_inference = True
+
+ asr_model.change_decoding_strategy(decoding_cfg, decoder_type=decoder_type)
asr_model = asr_model.to(args.device)
asr_model.eval()
Note that the "batched_inference = True" thing comes from this PR: #9100 . So you could delete that line for now. No need to apply this diff unless you wanted to run a CTC model.
BTW, this is how I'm running this same script using CTC right now:
speech_to_text_cache_aware_streaming_infer.py --asr_model=stt_en_fastconformer_hybrid_large_streaming_multi --manifest_file=/home/dgalvez/scratch/data/test_other_sorted_downward.json --us\
e_amp --set_decoder ctc
(I am working on improving streaming perf for CTC right now, so if you want to work on that as well, let me know first.)
(L x B x H, L x B x H) | ||
""" | ||
return ( | ||
torch.stack([state[0] for state in batch_states], dim=1).to(device=device, dtype=dtype), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In an ideal world, I think we would have a "streaming session" that holds a single contiguous memory buffer that can hold batch_size * number_of_batches
hidden states. Then you could implement these operations via scatter kernels and gather kernels rather than splitting and concatenating. However, I don't think it is easy for NeMo to do this the way that is currently written.
] | ||
|
||
@classmethod | ||
def batch_unsplit_states( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We will need to make this an abstract method in rnnt_abstract.py, right?
if partial_hypotheses is not None: | ||
raise NotImplementedError("`partial_hypotheses` support is not implemented") | ||
prev_labels = torch.tensor( | ||
[hyp.y_sequence[-1] if len(hyp.y_sequence) else self._blank_index for hyp in partial_hypotheses] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are self._blank_index and start of sequence token always guaranteed to the same?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For blank as pad
- yes (that's why these methods are called *_decode_blank_as_pad_*
)
|
||
if partial_hypotheses: | ||
for prev_hyp, hyp in zip(partial_hypotheses, hyps): | ||
hyp.y_sequence = torch.cat((prev_hyp.y_sequence, hyp.y_sequence)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't believe this is enough, if we want the results to match.
This is what I did for CTC (WIP!)
+ if previous_hypotheses is not None:
+ for i, hypothesis in enumerate(hypotheses_list):
+ # I am mutating previous_hypothses here... Perhaps not the wisest idea
+ previous_hypotheses[i].merge(hypothesis)
+ hypotheses_list = previous_hypotheses
This is the implementation of merge():
modified nemo/collections/asr/parts/utils/rnnt_utils.py
@@ -139,6 +139,27 @@ class Hypothesis:
"""
return [] if self.text is None else self.text.split()
+ # TODO: Does htis need to run inside of torch.inference_mode() or torch.no_grad()?
+ def merge(self, other):
+ self.score += other.score
+ # TODO: Consider what to do if this is a tensor, not a list. Concatenate?
+ self.y_sequence.extend(other.y_sequence)
+ self.dec_state = other.dec_state
+ if self.timestep is not None:
+ self.timestep.extend(other.timestep)
+ self.length += other.length
+ self.last_token = other.last_token
+
+ # TODO: Concatenate for alignments and frame_confidence.
+ if self.alignments is not None:
+ self.alignments[0] = torch.cat(self.alignments[0], other.alignments[0])
+ self.alignments[1] = torch.cat(self.alignments[1], other.alignments[1])
+ if self.frame_confidence is not None:
+ self.frame_confidence.extend(other.frame_confidence)
+
+ # Invalidated. Need to rerun decode_hypothesis here.
+ self.text = None
+
@dataclass
class NBestHypotheses:
# last found labels - initially <SOS> (<blank>) symbol | ||
self.state.labels.fill_(self._SOS) | ||
else: | ||
self.state.labels.copy_(prev_labels, non_blocking=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If these are both cuda tensors, I think non_blocking=True is unnecessary.
src_states=prev_state, dst_states=self.state.decoder_state, | ||
) | ||
|
||
if prev_labels is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In an ideal implementation, you would be able to support "continuous batching". What I mean is, if a single element finishes before the others in the batch, then a new element will be placed into the batch. I realize that NeMo isn't ready for this yet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
prev_labels
is a batch of the "last labels from the previous decoding iteration".
If you want to start a new sequence ("continuous batching"), just place <SOS>
symbol for the corresponding item (=blank for all our models, blank_as_pad=True
). This is already supported in this PR, see
https://github.com/NVIDIA/NeMo/pull/9106/files#diff-c695cda936f43469cd972499fb0bf557b5abbd852395e2c5f52faae37c95f4eeR752
We also need to make zero initial state for this item in the batch, I'm also planning to implement this.
@titu1994 if you could rereview at your convenience, that would be appreciated. |
This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days. |
What does this PR do ?
Add a one line overview of what this PR aims to accomplish.
Collection: [Note which collection this PR will affect]
Changelog
Usage
# Add a code snippet demonstrating how to use this
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items you can still open "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.
Additional Information