
Returned fixed point solution sometimes has requires_grad=False #8

Open
Boshi-Wang opened this issue Jun 27, 2024 · 0 comments

Boshi-Wang commented Jun 27, 2024

Thanks for the great library!

I'm learning to use the library by writing a DEQ layer inside GPT-2. After loss.backward(), some model parameters do not receive gradients. I tracked the issue down: the fixed point solution z_out[-1] returned by the DEQ layer (a tensor of shape (batch size, sequence length, hidden dimension)) has requires_grad set to False. Strangely, if I use a freshly-initialized GPT-2 model instead (without the pretrained weights), the issue is gone.
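For context, this is roughly how I check for the missing gradients (a minimal sketch; the squared-mean loss is just a stand-in for my actual training loss):

outputs = model(**inputs, use_cache=False)
loss = outputs.last_hidden_state.pow(2).mean()  # stand-in for the real loss
loss.backward()
# Parameters before the DEQ layer end up in this list; those after it
# (e.g. the final layer norm) still receive gradients
missing = [name for name, p in model.named_parameters() if p.grad is None]
print(missing)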

Specifically, this is my deq setup:

self.deq = get_deq(
    ift=True,
    f_solver='broyden', f_max_iter=15, f_tol=1e-3, f_stop_mode='rel',
    b_solver='broyden', b_max_iter=15, b_tol=1e-3, b_stop_mode='rel',
    )

If I print the requires_grad of the solution before and after the implicit function in each iteration, like so:

def layer_iter(hidden_states, input_hidden_states):
    print("before:", hidden_states.requires_grad)
    # forward pass through each transformer block, where input_hidden_states is the input embedding
    # ...
    print("after:", hidden_states.requires_grad)
    print("----")
    return hidden_states

func = lambda var: layer_iter(var, hidden_states_)    # hidden_states_ is the input after the embedding layer
zeros_ = torch.zeros(*hidden_states_.shape, requires_grad=True).to(hidden_states_.device)
z_out, info = self.deq(func, zeros_)
print("z_out[-1].requires_grad:", z_out[-1].requires_grad)
print(info)
hidden_states = z_out[-1]
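
(A side note on the snippet above: since requires_grad=True is set before the .to() call, zeros_ ends up as a non-leaf copy on the GPU. An equivalent leaf tensor created directly on the device would be the following; I show it only for completeness.)

# Initial guess created directly on the target device (a leaf tensor, unlike
# tensor.to(device), which returns a non-leaf copy inside the autograd graph)
zeros_ = torch.zeros(*hidden_states_.shape, device=hidden_states_.device, requires_grad=True)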

and with the main script as:

import torch
from torchdeq import get_deq
import matplotlib.pyplot as plt
from transformers import GPT2Tokenizer, GPT2Model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
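# Note: GPT2Model here is the version with the DEQ layer patched in (see the snippets above)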
model = GPT2Model.from_pretrained(model_name, attn_pdrop=0.0, embd_pdrop=0.0, resid_pdrop=0.0, summary_first_dropout=0.0)
model.to(device)

batch = ["we", "we"]
inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
inputs = {key: value.to(device) for key, value in inputs.items()}

outputs = model(**inputs, use_cache=False)

Then I get:

before: True
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
z_out[-1].requires_grad: False
{'abs_lowest': tensor([914.4200, 914.4200], device='cuda:0'), 'rel_lowest': tensor([6.5291e-05, 6.5291e-05], device='cuda:0'), 'abs_trace': tensor([[1.0000e+08, 1.4693e+03, 1.1808e+03, 2.8000e+03, 2.1719e+03, 9.1442e+02,
         1.0195e+03, 1.0512e+03, 1.0572e+03, 1.0586e+03, 9.1442e+02, 9.1442e+02,
         9.1442e+02, 9.1442e+02, 9.1442e+02, 9.1442e+02],
        [1.0000e+08, 1.4693e+03, 1.1808e+03, 2.8000e+03, 2.1719e+03, 9.1442e+02,
         1.0195e+03, 1.0512e+03, 1.0572e+03, 1.0586e+03, 9.1442e+02, 9.1442e+02,
         9.1442e+02, 9.1442e+02, 9.1442e+02, 9.1442e+02]], device='cuda:0'), 'rel_trace': tensor([[1.0000e+08, 1.3185e+00, 1.0195e+00, 6.6233e-01, 7.3519e-01, 2.3931e-01,
         4.5128e-01, 1.2702e-02, 2.3903e-03, 6.5291e-05, 6.5291e-05, 6.5291e-05,
         6.5291e-05, 6.5291e-05, 6.5291e-05, 6.5291e-05],
        [1.0000e+08, 1.3185e+00, 1.0195e+00, 6.6233e-01, 7.3519e-01, 2.3931e-01,
         4.5128e-01, 1.2702e-02, 2.3903e-03, 6.5291e-05, 6.5291e-05, 6.5291e-05,
         6.5291e-05, 6.5291e-05, 6.5291e-05, 6.5291e-05]], device='cuda:0'), 'nstep': tensor([9., 9.], device='cuda:0'), 'sradius': tensor([0.])}

It seems the model converges in 9 iterations, but the returned hidden states have requires_grad=False, so the model's parameters (except those after the DEQ layer) do not receive gradients. I tried manually setting z_out[-1].requires_grad = True, but this doesn't help; after loss.backward(), the .grad of those parameters is still None.
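
Concretely, the manual workaround I tried looks like this (compute_loss is a placeholder for my actual loss; flipping the flag does not rebuild the autograd graph upstream, which is presumably why it doesn't help):

z = z_out[-1]
z.requires_grad_(True)        # flips the flag, but z has no grad_fn, so the
                              # graph back to the earlier parameters is still missing
loss = compute_loss(z)        # compute_loss stands in for my actual loss
loss.backward()
print(model.wte.weight.grad)  # still None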

Intriguingly, if I use a freshly-initialized GPT-2, the issue seems to go away:

model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = GPT2Model.from_pretrained(model_name, attn_pdrop=0.0, embd_pdrop=0.0, resid_pdrop=0.0, summary_first_dropout=0.0)

# initialize the weights randomly
config = model.config
model = GPT2Model(config)

batch = ["we", "we"]
inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inputs = {key: value.to(device) for key, value in inputs.items()}
model.to(device)

outputs = model(**inputs, use_cache=False)

and I get

before: True
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: False
----
before: False
after: True
----
z_out[-1].requires_grad: True
{'abs_lowest': tensor([9.0937, 9.0937], device='cuda:0'), 'rel_lowest': tensor([0.0002, 0.0002], device='cuda:0'), 'abs_trace': tensor([[1.0000e+08, 9.0937e+00, 9.2320e+00, 9.3630e+00, 9.4192e+00, 9.4344e+00,
         9.4367e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00,
         9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00],
        [1.0000e+08, 9.0937e+00, 9.2320e+00, 9.3630e+00, 9.4192e+00, 9.4344e+00,
         9.4367e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00,
         9.0937e+00, 9.0937e+00, 9.0937e+00, 9.0937e+00]], device='cuda:0'), 'rel_trace': tensor([[1.0000e+08, 5.5195e-01, 2.3347e-01, 9.2865e-02, 2.5290e-02, 3.8527e-03,
         1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04,
         1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04],
        [1.0000e+08, 5.5195e-01, 2.3347e-01, 9.2865e-02, 2.5290e-02, 3.8527e-03,
         1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04,
         1.9339e-04, 1.9339e-04, 1.9339e-04, 1.9339e-04]], device='cuda:0'), 'nstep': tensor([6., 6.], device='cuda:0'), 'sradius': tensor([-1.])}
max/min/mean steps: 6.0, 6.0, 6.0

Here requires_grad becomes True. Also, apart from requires_grad, the sradius in info is now -1 instead of 0; I'm not sure whether that's related.

I wonder if you have any ideas about why this happens. I'd appreciate it!
