
Fix Quantization Aware Training for BEIT, Eva, and SwinTransformerV2 #2098

Open
wants to merge 1 commit into main
Conversation

clementpoiret

Dear all,

When performing Quantization Aware Training (QAT), modules get wrapped in a QuantWrapper.

But because some models implement the qkv projection with its biases via torch.nn.functional, the forward pass has to access self.qkv.weight directly.

During QAT, self.qkv.weight is no longer defined on the wrapped module, as in the error below:

Traceback:
Traceback (most recent call last):
File "/workspace/main.py", line 127, in main
  trainer.fit(model, data_module)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
  call._call_and_handle_interrupt(
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
  return trainer_fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
  self._run(model, ckpt_path=ckpt_path)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 989, in _run
  results = self._run_stage()
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1035, in _run_stage
  self.fit_loop.run()
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 202, in run
  self.advance()
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 359, in advance
  self.epoch_loop.run(self._data_fetcher)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 136, in run
  self.advance(data_fetcher)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 240, in advance
  batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 187, in run
  self._optimizer_step(batch_idx, closure)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 265, in _optimizer_step
  call._call_lightning_module_hook(
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 157, in _call_lightning_module_hook
  output = fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1282, in optimizer_step
  optimizer.step(closure=optimizer_closure)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py", line 151, in step
  step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 230, in optimizer_step
  return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/amp.py", line 77, in optimizer_step
  closure_result = closure()
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 140, in __call__
  self._result = self.closure(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
  return func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 126, in closure
  step_output = self._step_fn()
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 315, in _training_step
  training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
  output = fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 382, in training_step
  return self.lightning_module.training_step(*args, **kwargs)
File "/workspace/fringuantai/model.py", line 611, in training_step
  y_hat, confidence, y_hat_ctx = self(inputs, context=context)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
  return forward_call(*args, **kwargs)
File "/workspace/fringuantai/model.py", line 594, in forward
  return self.model(x, context=context, target_context=target_context)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
  return forward_call(*args, **kwargs)
File "/workspace/fringuantai/model.py", line 486, in forward
  features = self.backbone(x)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
  return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/timm/models/eva.py", line 580, in forward
  x = self.forward_features(x)
File "/opt/conda/lib/python3.10/site-packages/timm/models/eva.py", line 568, in forward_features
  x = blk(x, rope=rot_pos_embed)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
  return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/timm/models/eva.py", line 236, in forward
  x = x + self.drop_path1(self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
  return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/timm/models/eva.py", line 113, in forward
  qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
  raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'QuantWrapper' object has no attribute 'weight'
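
For context, a minimal standalone illustration of why the lookup fails: QuantWrapper keeps the wrapped layer under .module, so .weight is only reachable one level deeper.

```python
import torch.nn as nn
from torch.ao.quantization import QuantWrapper

qkv = nn.Linear(768, 768 * 3, bias=False)
wrapped = QuantWrapper(qkv)

print(hasattr(wrapped, "weight"))   # False -> the AttributeError above
print(wrapped.module.weight.shape)  # torch.Size([2304, 768])
```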

The fix is pretty straightforward: remove the call to self.qkv.weight and add qkv_bias directly to the output of self.qkv(x) where possible, along the lines of the sketch below.
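
A rough sketch of the idea (parameter names like q_bias/k_bias/v_bias follow the fused-qkv path in eva.py; this is not necessarily the exact diff):

```python
import torch

def forward_qkv(self, x):
    # Call the module itself so wrappers such as QuantWrapper keep working,
    # then add the concatenated bias instead of passing it to F.linear.
    qkv = self.qkv(x)  # nn.Linear(dim, dim * 3, bias=False), possibly wrapped
    if self.q_bias is not None:
        qkv = qkv + torch.cat((self.q_bias, self.k_bias, self.v_bias))
    return qkv
```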

@rwightman
Collaborator

@clementpoiret So my concern with this is that it changes a kernel that's likely a fused addmm into two separate kernels, which isn't great. Aren't there more advanced quantization options these days that can handle nn.Module and functional calls equally well?

@clementpoiret
Author

Hmm, I see. I can't really find a clean solution to that. Actually, F.linear quantizes fine (although it is slower than using nn.Linear in the current implementation).

Given the implementation of QuantWrapper, the only fix that comes to mind is to call F.linear with self.qkv.module.weight when self.qkv is an instance of QuantWrapper...
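
Something along these lines, as a rough sketch (assuming torch.ao.quantization.QuantWrapper is the wrapper in play; qkv_projection is just an illustrative helper name):

```python
import torch.nn.functional as F
from torch.ao.quantization import QuantWrapper

def qkv_projection(self, x, qkv_bias):
    # Reach through the wrapper when QAT has replaced self.qkv; otherwise use
    # the plain nn.Linear weight so the fused F.linear path is preserved.
    qkv_module = self.qkv.module if isinstance(self.qkv, QuantWrapper) else self.qkv
    return F.linear(input=x, weight=qkv_module.weight, bias=qkv_bias)
```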

But that bakes a fix for a very specific use case into a model definition that shouldn't be more complex than necessary, so I'm not really happy with it. What do you think? Another option is simply to document the issue, so users know they have to override the forward pass if they want to quantize the model.
