Skip to content

Commit

Permalink
[GPT2] Add correct keys on `_keys_to_ignore_on_load_unexpected` on all child classes of `GPT2PreTrainedModel` (#24113)
Browse files Browse the repository at this point in the history
…all child classes of `GPT2PreTrainedModel` (#24113)

* add correct keys on `_keys_to_ignore_on_load_unexpected`

* oops
Loading branch information
younesbelkada authored and sgugger committed Jun 8, 2023
1 parent b3e27a8 commit fe861e5
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,8 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
GPT2_START_DOCSTRING,
)
class GPT2Model(GPT2PreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]

def __init__(self, config):
super().__init__(config)
Expand Down Expand Up @@ -1149,6 +1150,7 @@ def _reorder_cache(
GPT2_START_DOCSTRING,
)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

def __init__(self, config):
Expand Down Expand Up @@ -1377,6 +1379,7 @@ def _reorder_cache(
GPT2_START_DOCSTRING,
)
class GPT2ForSequenceClassification(GPT2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]

def __init__(self, config):
Expand Down Expand Up @@ -1600,6 +1603,7 @@ def forward(
GPT2_START_DOCSTRING,
)
class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]

def __init__(self, config):
Expand Down

0 comments on commit fe861e5

Please sign in to comment.