
Ideas behind sharing parameters of policy model and value model? #1563

Open
MagiaSN opened this issue Nov 19, 2023 · 7 comments
Labels: enhancement (New feature or request), pending (This problem is yet to be addressed)

Comments

@MagiaSN

MagiaSN commented Nov 19, 2023

Thanks for this great work, but I wonder why the policy model and the value model share parameters here. In most of the literature, the policy model is initialized from the SFT model and the value model is initialized from the reward model, so they have separate parameters. I also ran some experiments with the same data and hyperparameters:

  1. Shared parameters, initialized from the SFT model (the default in this repo)
     [training curves]

  2. Separate parameters, both initialized from the SFT model
     [training curves]

  3. Separate parameters, policy model initialized from the SFT model, value model initialized from the reward model (the default in most of the literature)
     [training curves]

It seems that if the value model is initialized from the reward model, the initial value loss is much lower and the achieved reward is better.

The implementation for separating the parameters is quite straightforward: first add a value adapter and load it from the reward checkpoint:

# load the reward model's LoRA weights as a separate, trainable "value" adapter
model.pretrained_model.load_adapter(reward_model_path, "value", is_trainable=True)

Then run the forward pass twice with different adapters to get the logits and the values:

unwrapped_model.pretrained_model.set_adapter("value")
_, _, values = model(**input_kwargs)
unwrapped_model.pretrained_model.set_adapter("default")
logits, _, _ = model(**input_kwargs)
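(For context, a minimal sketch of the setup these snippets assume: a PEFT LoRA "default" adapter acting as the policy, wrapped in trl's AutoModelForCausalLMWithValueHead. The paths and LoRA settings below are placeholders, not this repo's actual configuration.)

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

sft_model_path = "path/to/sft_model"  # placeholder path

# the "default" LoRA adapter is the policy; trl adds the value head on top
base = AutoModelForCausalLM.from_pretrained(sft_model_path)
peft_model = get_peft_model(
    base,
    LoraConfig(task_type="CAUSAL_LM", target_modules=["query_key_value"]),  # adjust target_modules to the base model
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(peft_model)
# the "value" adapter is then loaded from the reward checkpoint as in the snippet above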
hiyouga added the enhancement and pending labels on Nov 19, 2023
@luo-li-ba-suo

If you do that, I think you would have to maintain two full-parameter models. But it seems not:

unwrapped_model.pretrained_model.set_adapter("value")
_, _, values = model(**input_kwargs)
unwrapped_model.pretrained_model.set_adapter("default")
logits, _, _ = model(**input_kwargs)

@hiyouga
Owner

hiyouga commented Nov 21, 2023

@luo-li-ba-suo The use of adapters means that it is LoRA tuning rather than full-parameter tuning, so only one full-parameter base model needs to be kept.

@luo-li-ba-suo

You are right. I just wonder whether two different LoRA adapters on one model can be trained simultaneously.

@hiyouga
Owner

hiyouga commented Nov 21, 2023

We would need to prevent the gradients of the inactive adapter from being disabled if we use multiple adapters in a PPO step:
https://github.com/BenjaminBossan/peft/blob/v0.6.1/src/peft/tuners/tuners_utils.py#L341
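One possible workaround (just a sketch, assuming two named PEFT LoRA adapters, "default" for the policy and "value" for the value backbone; not necessarily what this repo will end up doing): re-enable requires_grad on both adapters' LoRA parameters after each set_adapter call, so that both keep receiving gradients within the PPO step.

def keep_adapters_trainable(peft_model, adapter_names=("default", "value")):
    # PEFT's set_adapter() sets requires_grad=False on every non-active adapter,
    # so re-enable it manually for the adapters we want to train jointly
    for name, param in peft_model.named_parameters():
        if "lora_" in name and any(f".{adapter}." in name for adapter in adapter_names):
            param.requires_grad = True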

@MagiaSN
Author

MagiaSN commented Nov 24, 2023

@hiyouga I didn't know PEFT would disable grads for the inactive adapter when I was doing the above experiments. Today I checked the adapter_model.bin of the value adapter and found it was indeed updated:

The reward model weights that the value model was initialized with:

>>> reward_model_state_dict["base_model.model.transformer.encoder.layers.0.self_attention.query_key_value.lora_A.weight"]
tensor([[ 0.0203,  0.0008, -0.0167,  ...,  0.0026, -0.0003, -0.0018],
        [ 0.0007,  0.0115,  0.0013,  ...,  0.0174, -0.0172, -0.0120],
        [ 0.0075, -0.0182, -0.0136,  ...,  0.0130, -0.0143,  0.0088],
        ...,
        [ 0.0058,  0.0050,  0.0010,  ..., -0.0221,  0.0185, -0.0135],
        [ 0.0024,  0.0229, -0.0018,  ..., -0.0061,  0.0127, -0.0071],
        [-0.0103,  0.0158, -0.0023,  ..., -0.0065, -0.0201, -0.0134]])

And the weights of the value model after PPO:

>>> value_model_state_dict["base_model.model.transformer.encoder.layers.0.self_attention.query_key_value.lora_A.weight"]
tensor([[ 0.0204,  0.0003, -0.0168,  ...,  0.0026, -0.0004, -0.0018],
        [ 0.0007,  0.0112,  0.0011,  ...,  0.0174, -0.0173, -0.0119],
        [ 0.0076, -0.0191, -0.0139,  ...,  0.0133, -0.0145,  0.0089],
        ...,
        [ 0.0058,  0.0056,  0.0008,  ..., -0.0223,  0.0184, -0.0135],
        [ 0.0024,  0.0226, -0.0018,  ..., -0.0059,  0.0126, -0.0071],
        [-0.0102,  0.0151, -0.0025,  ..., -0.0064, -0.0202, -0.0134]])

Weird!
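(In case anyone wants to reproduce this check, a minimal sketch, assuming standard PEFT adapter_model.bin files; the directory names are placeholders:)

import torch

key = "base_model.model.transformer.encoder.layers.0.self_attention.query_key_value.lora_A.weight"
# load the saved LoRA adapter weights before and after PPO training
reward_model_state_dict = torch.load("reward_adapter/adapter_model.bin", map_location="cpu")  # placeholder dir
value_model_state_dict = torch.load("value_adapter/adapter_model.bin", map_location="cpu")    # placeholder dir

# if the value adapter had been frozen during PPO, the two tensors would be identical
print(torch.allclose(reward_model_state_dict[key], value_model_state_dict[key]))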

@MagiaSN
Author

MagiaSN commented Nov 24, 2023

My local branch has diverged from main and contains many irrelevant changes; I am picking out the minimal necessary changes for your reference: #1624

@luo-li-ba-suo

Hey, why hasn't this been adopted yet? Were any bugs found?
