Support partial hypotheses for batched RNN-T and TDT decoding #9106

Draft · wants to merge 2 commits into main

Conversation

artbataev
Collaborator

What does this PR do?

Add a one line overview of what this PR aims to accomplish.

Collection: [Note which collection this PR will affect]

Changelog

  • Add specific line by line info of high level changes in this PR.

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this 

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs in various areas.

Additional Information

  • Related to # (issue)

…abel-Looping algorithm)

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
@artbataev artbataev requested a review from galv May 3, 2024 14:25
@github-actions github-actions bot added the ASR label May 3, 2024
Collaborator

@galv galv left a comment

Mostly some FYI comments for now. Looks good to me overall.

@@ -327,10 +327,11 @@ def autocast():
     # configure the decoding config
     decoding_cfg = asr_model.cfg.decoding
     with open_dict(decoding_cfg):
-        decoding_cfg.strategy = "greedy"
+        decoding_cfg.strategy = "greedy_batch"
Collaborator

It's very confusing, but the "greedy_batch" strategy doesn't exist for CTC, so if you ran with --decoder_type=ctc, I think your code would crash.

The original logic is all messed up, but this ended up working for me for both CTC and RNN-T, IIRC:

@@ -325,14 +316,33 @@ def main():
             yield

     # configure the decoding config
+
+    if args.set_decoder is not None:
+        if hasattr(asr_model, "cur_decoder"):
+            decoder_type = args.set_decoder
+        else:
+            raise ValueError("Decoder cannot get changed for non-Hybrid ASR models.")
+    else:
+        decoder_type = None
+
+
     decoding_cfg = asr_model.cfg.decoding
     with open_dict(decoding_cfg):
         decoding_cfg.strategy = "greedy"
         decoding_cfg.preserve_alignments = False
-        if hasattr(asr_model, 'joint'):  # if an RNNT model
-            decoding_cfg.greedy.max_symbols = 10
+        if decoder_type == "rnnt":  # if an RNNT model
+            # We need partial hypothesis support here...
+            decoding_cfg.strategy = "greedy_batch"
             decoding_cfg.fused_batch_size = -1
-        asr_model.change_decoding_strategy(decoding_cfg)
+            decoding_cfg.greedy.max_symbols_per_step = 10
+            decoding_cfg.greedy.loop_labels = True
+            # TODO: Why isn't this working???
+            decoding_cfg.greedy.use_cuda_graph_decoder = True
+            # import ipdb; ipdb.set_trace()
+        elif decoder_type == "ctc":
+            decoding_cfg.greedy.batched_inference = True
+
+    asr_model.change_decoding_strategy(decoding_cfg, decoder_type=decoder_type)

     asr_model = asr_model.to(args.device)
     asr_model.eval()

Note that the "batched_inference = True" thing comes from PR #9100, so you could delete that line for now. There's no need to apply this diff unless you want to run a CTC model.

BTW, this is how I'm running this same script using CTC right now:

speech_to_text_cache_aware_streaming_infer.py --asr_model=stt_en_fastconformer_hybrid_large_streaming_multi --manifest_file=/home/dgalvez/scratch/data/test_other_sorted_downward.json --use_amp --set_decoder ctc

(I am working on improving streaming perf for CTC right now, so if you want to work on that as well, let me know first.)

            (L x B x H, L x B x H)
        """
        return (
            torch.stack([state[0] for state in batch_states], dim=1).to(device=device, dtype=dtype),
Collaborator

In an ideal world, I think we would have a "streaming session" that holds a single contiguous memory buffer that can hold batch_size * number_of_batches hidden states. Then you could implement these operations via scatter kernels and gather kernels rather than splitting and concatenating. However, I don't think it is easy for NeMo to do this the way that is currently written.
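To make this concrete, here is a rough sketch of that idea (hypothetical class and names, assuming LSTM-style (L, B, H) states; this is not NeMo code):

import torch

class StreamingStateBuffer:
    """Hypothetical sketch: one contiguous buffer holding all streams' decoder states."""

    def __init__(self, num_layers: int, max_streams: int, hidden: int, device: str = "cpu"):
        # Preallocate (L, max_streams, H) once; no per-chunk allocation.
        self.h = torch.zeros(num_layers, max_streams, hidden, device=device)
        self.c = torch.zeros(num_layers, max_streams, hidden, device=device)

    def gather(self, stream_ids: torch.Tensor):
        # Assemble a decoding batch via indexed reads instead of torch.cat.
        return self.h[:, stream_ids], self.c[:, stream_ids]

    def scatter(self, stream_ids: torch.Tensor, h: torch.Tensor, c: torch.Tensor) -> None:
        # Write updated states back via indexed writes instead of splitting.
        self.h[:, stream_ids] = h
        self.c[:, stream_ids] = c

Gathering and scattering over a preallocated buffer keeps memory stable across chunks and avoids the split/concat traffic.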

        ]

    @classmethod
    def batch_unsplit_states(
Collaborator

We will need to make this an abstract method in rnnt_abstract.py, right?
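For reference, the declaration might look something like this (a sketch only; the signature is assumed from the concrete implementation above):

from abc import ABC, abstractmethod

class AbstractRNNTDecoder(ABC):  # simplified stand-in for the real base class
    @classmethod
    @abstractmethod
    def batch_unsplit_states(cls, batch_states, device=None, dtype=None):
        """Merge a list of single-sample decoder states back into one batched state."""
        raise NotImplementedError()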

-        if partial_hypotheses is not None:
-            raise NotImplementedError("`partial_hypotheses` support is not implemented")
+        prev_labels = torch.tensor(
+            [hyp.y_sequence[-1] if len(hyp.y_sequence) else self._blank_index for hyp in partial_hypotheses]
Collaborator

Are self._blank_index and the start-of-sequence token always guaranteed to be the same?

Collaborator Author

For blank-as-pad, yes (that's why these methods are called *_decode_blank_as_pad_*).
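A minimal illustration of why blank can double as <SOS> under blank-as-pad (the sizes and blank index placement are assumptions for the sketch, not NeMo's exact code):

import torch

vocab_size, hidden = 28, 320
blank_idx = vocab_size  # in this sketch, blank is the last index

# blank_as_pad: blank maps to an all-zero embedding (padding_idx), so feeding
# blank as the first label acts as a neutral <SOS> input to the prediction network.
embed = torch.nn.Embedding(vocab_size + 1, hidden, padding_idx=blank_idx)

sos = torch.full((4, 1), blank_idx)  # batch of "start" labels
assert embed(sos).abs().sum().item() == 0.0  # zero vector == neutral start input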


        if partial_hypotheses:
            for prev_hyp, hyp in zip(partial_hypotheses, hyps):
                hyp.y_sequence = torch.cat((prev_hyp.y_sequence, hyp.y_sequence))
Collaborator

I don't believe this is enough if we want the results to match.

This is what I did for CTC (WIP!)

+            if previous_hypotheses is not None:
+                for i, hypothesis in enumerate(hypotheses_list):
+                    # I am mutating previous_hypotheses here... Perhaps not the wisest idea
+                    previous_hypotheses[i].merge(hypothesis)
+                hypotheses_list = previous_hypotheses

This is the implementation of merge():

modified   nemo/collections/asr/parts/utils/rnnt_utils.py
@@ -139,6 +139,27 @@ class Hypothesis:
         """
         return [] if self.text is None else self.text.split()

+    # TODO: Does this need to run inside of torch.inference_mode() or torch.no_grad()?
+    def merge(self, other):
+        self.score += other.score
+        # TODO: Consider what to do if this is a tensor, not a list. Concatenate?
+        self.y_sequence.extend(other.y_sequence)
+        self.dec_state = other.dec_state
+        if self.timestep is not None:
+            self.timestep.extend(other.timestep)
+        self.length += other.length
+        self.last_token = other.last_token
+
+        # TODO: Concatenate for alignments and frame_confidence.
+        if self.alignments is not None:
+            self.alignments[0] = torch.cat((self.alignments[0], other.alignments[0]))
+            self.alignments[1] = torch.cat((self.alignments[1], other.alignments[1]))
+        if self.frame_confidence is not None:
+            self.frame_confidence.extend(other.frame_confidence)
+
+        # Invalidated. Need to rerun decode_hypothesis here.
+        self.text = None
+

 @dataclass
 class NBestHypotheses:

            # last found labels - initially <SOS> (<blank>) symbol
            self.state.labels.fill_(self._SOS)
        else:
            self.state.labels.copy_(prev_labels, non_blocking=True)
Collaborator

If these are both cuda tensors, I think non_blocking=True is unnecessary.
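A small illustration of that point (not from the PR):

import torch

if torch.cuda.is_available():
    src = torch.randn(4, 8, device="cuda")
    dst = torch.empty_like(src)
    # CUDA-to-CUDA copies are enqueued on the current stream and are already
    # asynchronous with respect to the host, with or without non_blocking.
    dst.copy_(src)

    # non_blocking mainly helps host<->device transfers using pinned memory,
    # where it lets the host avoid waiting for the copy to complete.
    pinned = torch.randn(4, 8).pin_memory()
    dst.copy_(pinned, non_blocking=True)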

            src_states=prev_state, dst_states=self.state.decoder_state,
        )

        if prev_labels is None:
Collaborator

In an ideal implementation, you would be able to support "continuous batching". What I mean is, if a single element finishes before the others in the batch, then a new element will be placed into the batch. I realize that NeMo isn't ready for this yet.

Collaborator Author

prev_labels is a batch of the "last labels from the previous decoding iteration".
If you want to start a new sequence ("continuous batching"), just place the <SOS> symbol for the corresponding item (= blank for all our models, since blank_as_pad=True). This is already supported in this PR, see
https://github.com/NVIDIA/NeMo/pull/9106/files#diff-c695cda936f43469cd972499fb0bf557b5abbd852395e2c5f52faae37c95f4eeR752

We also need to zero the initial state for this item in the batch; I'm planning to implement this as well.
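A sketch of what that per-slot reset could look like (hypothetical helper with assumed LSTM-style (L, B, H) states; not part of this PR):

import torch

def reset_stream_slot(labels: torch.Tensor, decoder_state, slot: int, sos_idx: int) -> None:
    """Hypothetical: restart decoding for one element of the batch in place."""
    labels[slot] = sos_idx        # <SOS> == blank, since blank_as_pad=True
    h, c = decoder_state          # assumed (L, B, H) state tensors
    h[:, slot].zero_()            # zero initial state for the new sequence
    c[:, slot].zero_()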

@galv
Collaborator

galv commented May 6, 2024

@titu1994 if you could rereview at your convenience, that would be appreciated.

Contributor

This PR is stale because it has been open for 14 days with no activity. Remove the stale label, comment, or update the PR, or it will be closed in 7 days.

@github-actions github-actions bot added the stale label May 21, 2024