From 0996a10077219de0556281511fc02f3ab68002d5 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Tue, 20 Feb 2024 12:06:46 +0000 Subject: [PATCH] Revert low cpu mem tie weights (#29135) * Revert "Add tie_weights() to LM heads and set bias in set_output_embeddings() (#28948)" This reverts commit 725f4ad1ccad4e1aeb309688706b56713070334b. * Revert "Patch to skip failing `test_save_load_low_cpu_mem_usage` tests (#29043)" This reverts commit 4156f517ce0f00e0b7842410542aad5fe37e73cf. --- src/transformers/models/bert/modeling_bert.py | 6 ------ .../models/big_bird/modeling_big_bird.py | 6 ------ .../models/blip/modeling_blip_text.py | 4 ---- src/transformers/models/ernie/modeling_ernie.py | 6 ------ .../models/layoutlm/modeling_layoutlm.py | 4 ---- .../models/markuplm/modeling_markuplm.py | 3 --- .../megatron_bert/modeling_megatron_bert.py | 6 ------ src/transformers/models/mpnet/modeling_mpnet.py | 4 ---- src/transformers/models/mra/modeling_mra.py | 4 ---- src/transformers/models/nezha/modeling_nezha.py | 5 ----- .../nystromformer/modeling_nystromformer.py | 4 ---- .../models/qdqbert/modeling_qdqbert.py | 5 ----- .../models/roc_bert/modeling_roc_bert.py | 6 ------ src/transformers/models/tapas/modeling_tapas.py | 4 ---- src/transformers/models/vilt/modeling_vilt.py | 4 ---- .../models/visual_bert/modeling_visual_bert.py | 4 ---- src/transformers/models/yoso/modeling_yoso.py | 4 ---- .../test_modeling_bert_generation.py | 6 ------ .../test_modeling_deformable_detr.py | 4 ---- tests/models/deta/test_modeling_deta.py | 4 ---- tests/models/fsmt/test_modeling_fsmt.py | 6 ------ tests/models/marian/test_modeling_marian.py | 6 ------ tests/models/musicgen/test_modeling_musicgen.py | 4 ---- tests/models/reformer/test_modeling_reformer.py | 12 ------------ .../test_modeling_xlm_roberta_xl.py | 6 ------ tests/test_modeling_common.py | 17 ----------------- 26 files changed, 144 deletions(-) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index ea5bae4a8bb435..4c068c4d4f1d76 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -692,9 +692,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1065,7 +1062,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -1175,7 +1171,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -1329,7 +1324,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 6e3af915cf8b36..008985f760e867 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1707,9 +1707,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -2269,7 +2266,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=BigBirdForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -2382,7 +2378,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) @@ -2524,7 +2519,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index fa9b1e0e4fc476..808c33f8104fc1 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -523,9 +523,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -820,7 +817,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias def forward( self, diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 1a1e49dcbf16a9..291ab6c54d1e50 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -608,9 +608,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -998,7 +995,6 @@ def get_output_embeddings(self): # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=ErnieForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -1113,7 +1109,6 @@ def get_output_embeddings(self): # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -1274,7 +1269,6 @@ def get_output_embeddings(self): # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 70d11573d9251e..c2ecede73d3955 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -589,9 +589,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -872,7 +869,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 8d95bcc0c169c5..24ca0c4972aaa0 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -318,9 +318,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 0fd9127bab2440..9111f937bc2a06 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -659,9 +659,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1026,7 +1023,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -1136,7 +1132,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @@ -1295,7 +1290,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/mpnet/modeling_mpnet.py b/src/transformers/models/mpnet/modeling_mpnet.py index 43cfaa5e69a140..86194607e21750 100644 --- a/src/transformers/models/mpnet/modeling_mpnet.py +++ b/src/transformers/models/mpnet/modeling_mpnet.py @@ -587,7 +587,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings - self.lm_head.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -660,9 +659,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, features, **kwargs): x = self.dense(features) x = gelu(x) diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index 9915db471ef308..6e33753817027c 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -810,9 +810,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1046,7 +1043,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index 8fc2041e931ded..918a10b2759a2d 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -679,9 +679,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1047,7 +1044,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=NezhaForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -1156,7 +1152,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 1bba9fb1f85bc3..950f8d27fa8e5a 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -428,9 +428,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -669,7 +666,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index ff4b5441ea8084..8c610ecaedbfc4 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -683,9 +683,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1027,7 +1024,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @@ -1194,7 +1190,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index ded234b71cb6d5..f3de92fed38941 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -744,9 +744,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1093,7 +1090,6 @@ def get_output_embeddings(self): # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) @@ -1286,7 +1282,6 @@ def get_output_embeddings(self): # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def forward( @@ -1424,7 +1419,6 @@ def get_output_embeddings(self): # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 1ee233ea9d7f6d..1e7a4372bb015e 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -729,9 +729,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1011,7 +1008,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 5e53d4332bd30e..9ffa9fff013c88 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -896,7 +896,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.mlm_score.decoder = new_embeddings - self.mlm_score.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) @@ -1043,9 +1042,6 @@ def __init__(self, config, weight=None): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, x): x = self.transform(x) x = self.decoder(x) diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 68e77505e12865..4af7696fc39634 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -499,9 +499,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -882,7 +879,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=VisualBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index ab6fb1c151c0db..5361adc3ed48e0 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -626,9 +626,6 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -867,7 +864,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings - self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(YOSO_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/tests/models/bert_generation/test_modeling_bert_generation.py b/tests/models/bert_generation/test_modeling_bert_generation.py index 4e0e3dc8e1c9f8..ecd7a459e0ea8d 100644 --- a/tests/models/bert_generation/test_modeling_bert_generation.py +++ b/tests/models/bert_generation/test_modeling_bert_generation.py @@ -305,12 +305,6 @@ def test_model_from_pretrained(self): model = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder") self.assertIsNotNone(model) - @unittest.skip( - "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" - ) - def test_save_load_low_cpu_mem_usage(self): - pass - @require_torch class BertGenerationEncoderIntegrationTest(unittest.TestCase): diff --git a/tests/models/deformable_detr/test_modeling_deformable_detr.py b/tests/models/deformable_detr/test_modeling_deformable_detr.py index c1268fff3c6e64..5b123884e9cc53 100644 --- a/tests/models/deformable_detr/test_modeling_deformable_detr.py +++ b/tests/models/deformable_detr/test_modeling_deformable_detr.py @@ -564,10 +564,6 @@ def test_initialization(self): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - @unittest.skip("Cannot be initialized on meta device as some weights are modified during the initialization") - def test_save_load_low_cpu_mem_usage(self): - pass - def test_two_stage_training(self): model_class = DeformableDetrForObjectDetection config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/deta/test_modeling_deta.py b/tests/models/deta/test_modeling_deta.py index ffebfd38d0eba3..3a3a957dd012e2 100644 --- a/tests/models/deta/test_modeling_deta.py +++ b/tests/models/deta/test_modeling_deta.py @@ -520,10 +520,6 @@ def test_initialization(self): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - @unittest.skip("Cannot be initialized on meta device as some weights are modified during the initialization") - def test_save_load_low_cpu_mem_usage(self): - pass - TOLERANCE = 1e-4 diff --git a/tests/models/fsmt/test_modeling_fsmt.py b/tests/models/fsmt/test_modeling_fsmt.py index 18ee40e471ae9f..da73b8d41d9902 100644 --- a/tests/models/fsmt/test_modeling_fsmt.py +++ b/tests/models/fsmt/test_modeling_fsmt.py @@ -329,12 +329,6 @@ def test_tie_model_weights(self): def test_resize_embeddings_untied(self): pass - @unittest.skip( - "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" - ) - def test_save_load_low_cpu_mem_usage(self): - pass - @require_torch class FSMTHeadTests(unittest.TestCase): diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py index e393c7d10325a8..53a67c20459f58 100644 --- a/tests/models/marian/test_modeling_marian.py +++ b/tests/models/marian/test_modeling_marian.py @@ -372,12 +372,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip( - "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" - ) - def test_save_load_low_cpu_mem_usage(self): - pass - def assert_tensors_close(a, b, atol=1e-12, prefix=""): """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index 284450a00af5f9..b7952d27a71592 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -1144,10 +1144,6 @@ def test_greedy_generate_stereo_outputs(self): self.assertNotIn(config.pad_token_id, output_generate) - @unittest.skip("Fails with - TypeError: _weight_norm_interface() missing 1 required positional argument: 'dim'") - def test_save_load_low_cpu_mem_usage(self): - pass - def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000): """Produces a series of 'bip bip' sounds at a given frequency.""" diff --git a/tests/models/reformer/test_modeling_reformer.py b/tests/models/reformer/test_modeling_reformer.py index b1796a6c534d4e..11cd7e1a33b45a 100644 --- a/tests/models/reformer/test_modeling_reformer.py +++ b/tests/models/reformer/test_modeling_reformer.py @@ -687,12 +687,6 @@ def _check_hidden_states_for_generate( def test_left_padding_compatibility(self): pass - @unittest.skip( - "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" - ) - def test_save_load_low_cpu_mem_usage(self): - pass - @require_torch class ReformerLSHAttnModelTest( @@ -854,12 +848,6 @@ def test_past_key_values_format(self): def test_left_padding_compatibility(self): pass - @unittest.skip( - "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" - ) - def test_save_load_low_cpu_mem_usage(self): - pass - @require_torch @require_sentencepiece diff --git a/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py index c6513ef79628bd..828d6a02a6a368 100644 --- a/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py +++ b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py @@ -515,12 +515,6 @@ def test_create_position_ids_from_inputs_embeds(self): self.assertEqual(position_ids.shape, expected_positions.shape) self.assertTrue(torch.all(torch.eq(position_ids, expected_positions))) - @unittest.skip( - "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" - ) - def test_save_load_low_cpu_mem_usage(self): - pass - @require_torch class XLMRobertaModelXLIntegrationTest(unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index dfe613fa1fd7db..32f6abcbe3aad1 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -435,23 +435,6 @@ class CopyClass(model_class): max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") - def test_save_load_low_cpu_mem_usage(self): - with tempfile.TemporaryDirectory() as tmpdirname: - for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model_to_save = model_class(config) - - model_to_save.save_pretrained(tmpdirname) - - model = model_class.from_pretrained( - tmpdirname, - low_cpu_mem_usage=True, - ) - - # The low_cpu_mem_usage=True causes the model params to be initialized with device=meta. If there are - # any unloaded or untied parameters, then trying to move it to device=torch_device will throw an error. - model.to(torch_device) - def test_fast_init_context_manager(self): # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__ class MyClass(PreTrainedModel):