From 04de6d041f7fc568183342d3c2e4b907ec5c1ebf Mon Sep 17 00:00:00 2001
From: Xavier Dupre
Date: Thu, 27 Jun 2024 20:51:25 +0200
Subject: [PATCH 1/4] Simple unit test to check backward function for conv

Signed-off-by: Xavier Dupre
---
 .../function_libs/torch_lib/backward_test.py | 71 +++++++++++++++++++
 1 file changed, 71 insertions(+)
 create mode 100644 onnxscript/function_libs/torch_lib/backward_test.py

diff --git a/onnxscript/function_libs/torch_lib/backward_test.py b/onnxscript/function_libs/torch_lib/backward_test.py
new file mode 100644
index 000000000..6b6eb7a62
--- /dev/null
+++ b/onnxscript/function_libs/torch_lib/backward_test.py
@@ -0,0 +1,71 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+# pylint: disable=not-callable
+
+import copy
+import sys
+import unittest
+
+import torch
+
+import onnxscript.tools.training_helper
+import onnxscript.tools.transformers_models
+import onnxscript.tools.transformers_models.llama
+from onnxscript._internal.version_utils import has_transformers, torch_older_than
+
+
+class TestBackward(unittest.TestCase):
+    @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
+    @unittest.skipIf(not has_transformers(), reason="transformers is missing")
+    @unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
+    def test_backward_conv(self):
+        class SimpleCNNN(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.conv1 = torch.nn.Conv2d(
+                    in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1
+                )  # Output size 14x14 for MNIST
+                self.conv2 = torch.nn.Conv2d(
+                    in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1
+                )  # Output size 7x7 for MNIST
+
+                self.fc1 = torch.nn.Linear(64 * 7 * 7, 128)
+                self.fc2 = torch.nn.Linear(128, 10)
+
+            def forward(self, x):
+                x = torch.nn.functional.relu(self.conv1(x))
+                x = torch.nn.functional.relu(self.conv2(x))
+
+                x = x.view(-1, 64 * 7 * 7)
+
+                x = torch.nn.functional.relu(self.fc1(x))
+                x = self.fc2(x)
+
+                return x
+
+        input_tensors = (torch.randn(1, 1, 32, 32),)
+        model = SimpleCNNN()
+        expected = model(*input_tensors)
+
+        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)
+
+        compiled_model = torch.compile(
+            copy.deepcopy(model),
+            backend=local_aot_ort,
+            dynamic=False,
+            fullgraph=True,
+        )
+
+        results = compiled_model(*input_tensors)
+        torch.testing.assert_allclose(expected[0], results[0], atol=1e-5, rtol=1e-5)
+
+        expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
+        gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
+        torch.testing.assert_allclose(
+            expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
+        )
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
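Patch 2 below reworks `train_loop` in `onnxscript/tools/training_helper.py` so that it can also dump the ONNX graphs produced by the onnxrt backend. The mechanism the patch relies on is the `ONNXRT_DUMP_PATH` environment variable, which the backend reads when it exports a graph: one prefix is set before the forward pass and another before the backward pass. A minimal sketch of that idea outside the helper follows; the toy model, the dump prefix and the expected file count are illustrative assumptions, not values taken from the patch.

import glob
import os

import torch

from onnxscript.tools.training_helper import make_aot_ort

# Illustrative toy model; any module that compiles without graph breaks behaves the same way.
model = torch.nn.Linear(14, 10)
compiled = torch.compile(model, backend=make_aot_ort(dynamic=False), fullgraph=True)

os.environ["ONNXRT_DUMP_PATH"] = "_dump_example_forward"  # prefix used when the forward graph is exported
loss = compiled(torch.randn(2, 14)).sum()
os.environ["ONNXRT_DUMP_PATH"] = "_dump_example_backward"  # prefix used when the backward graph is exported
loss.backward()

# With no graph break there should be one forward and one backward graph.
print(sorted(glob.glob("_dump_example*.onnx")))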
From 77dafc3b3b552a8f492904e80f7c1ed7475a5cdc Mon Sep 17 00:00:00 2001
From: Xavier Dupre
Date: Fri, 28 Jun 2024 11:09:58 +0200
Subject: [PATCH 2/4] complete

Signed-off-by: Xavier Dupre
---
 .gitignore                                     |  1 +
 .../function_libs/torch_lib/backward_test.py   | 86 ++++++++++++++-----
 onnxscript/tools/training_helper.py            | 67 ++++++++++++++-
 .../tools/transformers_models/llama_test.py    | 19 ++--
 .../tools/transformers_models/mistral_test.py  | 23 +++--
 .../tools/transformers_models/phi3_test.py     | 23 +++--
 .../tools/transformers_models/phi_test.py      | 22 +++--
 7 files changed, 188 insertions(+), 53 deletions(-)

diff --git a/.gitignore b/.gitignore
index 9e6f1a45c..203580c52 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,6 +42,7 @@ coverage.xml
 cover/
 test-output.xml
 *.sarif
+_dump_*
 
 # Sphinx documentation
 docs/_build/
diff --git a/onnxscript/function_libs/torch_lib/backward_test.py b/onnxscript/function_libs/torch_lib/backward_test.py
index 6b6eb7a62..0611196f6 100644
--- a/onnxscript/function_libs/torch_lib/backward_test.py
+++ b/onnxscript/function_libs/torch_lib/backward_test.py
@@ -3,6 +3,7 @@
 # pylint: disable=not-callable
 
 import copy
+import os
 import sys
 import unittest
 
@@ -18,36 +19,66 @@ class TestBackward(unittest.TestCase):
     @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows")
     @unittest.skipIf(not has_transformers(), reason="transformers is missing")
     @unittest.skipIf(torch_older_than("2.4"), reason="fails to export")
-    def test_backward_conv(self):
+    def test_backward_working(self):
         class SimpleCNNN(torch.nn.Module):
             def __init__(self):
                 super().__init__()
 
-                self.conv1 = torch.nn.Conv2d(
-                    in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1
-                )  # Output size 14x14 for MNIST
-                self.conv2 = torch.nn.Conv2d(
-                    in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1
-                )  # Output size 7x7 for MNIST
-
-                self.fc1 = torch.nn.Linear(64 * 7 * 7, 128)
-                self.fc2 = torch.nn.Linear(128, 10)
+                self.fc1 = torch.nn.Linear(14, 10)
 
             def forward(self, x):
-                x = torch.nn.functional.relu(self.conv1(x))
-                x = torch.nn.functional.relu(self.conv2(x))
+                return torch.nn.functional.relu(self.fc1(x))
 
-                x = x.view(-1, 64 * 7 * 7)
+        input_tensors = (torch.randn(1, 1, 14, 14),)
+        model = SimpleCNNN()
+        local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)
 
-                x = torch.nn.functional.relu(self.fc1(x))
-                x = self.fc2(x)
+        compiled_model = torch.compile(
+            copy.deepcopy(model),
+            backend=local_aot_ort,
+            dynamic=False,
+            fullgraph=True,
+        )
 
-                return x
+        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(
+            model, *input_tensors
+        )
+        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
+            compiled_model,
+            *input_tensors,
+            dump_onnx_models=True,
+            dump_prefix="_dump_testbw_working",
+            dump_clean_first=True,
+        )
+        torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)
 
-        input_tensors = (torch.randn(1, 1, 32, 32),)
-        model = SimpleCNNN()
-        expected = model(*input_tensors)
+        # Check that exactly two graphs were generated; otherwise there are graph breaks.
+ self.assertEqual(len(onnx_models), 2) + torch.testing.assert_allclose( + expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5 + ) + + @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") + @unittest.skipIf(not has_transformers(), reason="transformers is missing") + @unittest.skipIf(torch_older_than("2.4"), reason="fails to export") + @unittest.skipIf(True, reason="aten.conv_backward not implemented yet.") + def test_backward_conv(self): + class SimpleCNNN(torch.nn.Module): + def __init__(self): + super().__init__() + + self.conv1 = torch.nn.Conv2d( + in_channels=1, out_channels=2, kernel_size=3, padding=1 + ) + self.fc1 = torch.nn.Linear(14, 10) + + def forward(self, x): + y = torch.nn.functional.relu(self.conv1(x)) + z = self.fc1(y) + return z + input_tensors = (torch.randn(1, 1, 14, 14),) + model = SimpleCNNN() local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False) compiled_model = torch.compile( @@ -57,11 +88,20 @@ def forward(self, x): fullgraph=True, ) - results = compiled_model(*input_tensors) - torch.testing.assert_allclose(expected[0], results[0], atol=1e-5, rtol=1e-5) + expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop( + model, *input_tensors + ) + results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop( + compiled_model, + *input_tensors, + dump_onnx_models=True, + dump_prefix="_dump_testbw_conv", + dump_clean_first=True, + ) + torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5) - expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors) - gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors) + # Checking there is only two generated graphs otherwise, it means there are graph breaks. + self.assertEqual(len(onnx_models), 2) torch.testing.assert_allclose( expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5 ) diff --git a/onnxscript/tools/training_helper.py b/onnxscript/tools/training_helper.py index 785b2e6fb..6fe245ba0 100644 --- a/onnxscript/tools/training_helper.py +++ b/onnxscript/tools/training_helper.py @@ -2,13 +2,19 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations + +import glob +import os +from typing import Any + import torch from torch.onnx import ExportOptions from torch.onnx import _OrtBackend as OrtBackend from torch.onnx import _OrtBackendOptions as OrtBackendOptions -def make_aot_ort(dynamic: bool = False): +def make_aot_ort(dynamic: bool = False) -> Any: """Implements an autograd backend for torch.compile based on onnxrt backend.""" export_options = ExportOptions(dynamic_shapes=dynamic) options = OrtBackendOptions(export_options=export_options) @@ -16,8 +22,39 @@ def make_aot_ort(dynamic: bool = False): return ort_backend -def train_loop(model, *args, loss_fn=None, optimizer=None): - """Implements a training loop to be used in tests.""" +def train_loop( + model: Any, + *args, + loss_fn: Any | None = None, + optimizer: Any | None = None, + dump_onnx_models: bool = False, + dump_prefix: str = "dump_train_loop", + dump_clean_first: bool = True, +) -> tuple[Any, tuple[Any, ...]] | tuple[Any, tuple[Any, ...], list[str]]: + """Implements a training loop to be used in tests. + The function returns the forward output and gradients in a tuple. 
+
+    If dump_onnx_models is True, the function returns the forward output,
+    the gradients and the list of generated ONNX files.
+    If there is no graph break, there should be exactly
+    two graphs, one for the forward pass and one for the backward pass.
+
+    Args:
+        model: pytorch model
+        args: inputs
+        loss_fn: loss function, default is MSELoss
+        optimizer: optimizer, default is SGD
+        dump_onnx_models: dumps the models the onnxrt backend produces
+        dump_prefix: prefix for the dumped models, names will be `0.onnx`, `1.onnx`, ...
+        dump_clean_first: removes all files starting with the given prefix first
+
+    Returns:
+        - the forward outputs
+        - the backward gradients
+        - the dumped onnx models, at least two unless the forward and backward
+          passes were already compiled before this function was executed or
+          the model is not a compiled model
+    """
     if loss_fn is None:
         loss_fn = torch.nn.MSELoss()
@@ -28,6 +65,16 @@ def train_loop(model, *args, loss_fn=None, optimizer=None):
     # Unnecessary in this situation but added for best practices
     model.train()
 
+    if dump_onnx_models:
+        if dump_clean_first:
+            names = glob.glob(f"{dump_prefix}*")
+            for name in names:
+                os.remove(name)
+
+        old_value = os.environ.get("ONNXRT_DUMP_PATH", None)
+        os.environ["ONNXRT_DUMP_PATH"] = f"{dump_prefix}_forward"
+        existing_files = glob.glob(f"{dump_prefix}*.onnx")
+
     # Compute prediction and loss
     pred = model(*args)
     if isinstance(pred, tuple):
@@ -39,6 +86,8 @@ def train_loop(model, *args, loss_fn=None, optimizer=None):
         loss = loss_fn(v, torch.ones_like(v))
 
     # Backpropagation
+    if dump_onnx_models:
+        os.environ["ONNXRT_DUMP_PATH"] = f"{dump_prefix}_backward"
     loss.backward()
     optimizer.step()
     # skip that part to retrieve the gradients
@@ -47,4 +96,14 @@ def train_loop(model, *args, loss_fn=None, optimizer=None):
     # returns the gradients
     res = tuple(p.grad for p in model.parameters() if p.grad is not None)
     assert len(res) > 0, f"No gradient, loss is {loss}"
-    return res
+
+    if dump_onnx_models:
+        if old_value is None:
+            del os.environ["ONNXRT_DUMP_PATH"]
+        else:
+            os.environ["ONNXRT_DUMP_PATH"] = old_value
+        new_files = glob.glob(f"{dump_prefix}*.onnx")
+        added_files = set(new_files) - set(existing_files)
+        return pred, res, [f for f in new_files if f in added_files]
+
+    return pred, res
diff --git a/onnxscript/tools/transformers_models/llama_test.py b/onnxscript/tools/transformers_models/llama_test.py
index ccfe722f9..a4365aacd 100644
--- a/onnxscript/tools/transformers_models/llama_test.py
+++ b/onnxscript/tools/transformers_models/llama_test.py
@@ -66,8 +66,6 @@ def test_llama_dort_static(self):
             onnxscript.tools.transformers_models.llama.get_llama_model()
         )
         input_tensors = input_tensors_many[0]
-        expected = model(*input_tensors)
-
         local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False)
 
         compiled_model = torch.compile(
@@ -77,11 +75,20 @@ def test_llama_dort_static(self):
             fullgraph=True,
         )
 
-        results = compiled_model(*input_tensors)
-        torch.testing.assert_allclose(expected[0], results[0], atol=1e-5, rtol=1e-5)
+        expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop(
+            model, *input_tensors
+        )
+        results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop(
+            compiled_model,
+            *input_tensors,
+            dump_onnx_models=True,
+            dump_prefix="_dump_dort_llama",
+            dump_clean_first=True,
+        )
+        torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)
 
-        expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors)
-        gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors)
+        #
Checking there is only two generated graphs otherwise, it means there are graph breaks. + self.assertEqual(len(onnx_models), 2) torch.testing.assert_allclose( expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5 ) diff --git a/onnxscript/tools/transformers_models/mistral_test.py b/onnxscript/tools/transformers_models/mistral_test.py index f1885c950..0516af74b 100644 --- a/onnxscript/tools/transformers_models/mistral_test.py +++ b/onnxscript/tools/transformers_models/mistral_test.py @@ -71,8 +71,6 @@ def test_phi_dort_static(self): onnxscript.tools.transformers_models.mistral.get_mistral_model() ) input_tensors = input_tensors_many[0] - expected = model(*input_tensors) - local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False) compiled_model = torch.compile( @@ -82,12 +80,23 @@ def test_phi_dort_static(self): fullgraph=True, ) - results = compiled_model(*input_tensors) - torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5) + expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop( + model, *input_tensors + ) + results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop( + compiled_model, + *input_tensors, + dump_onnx_models=True, + dump_prefix="_dump_dort_mistral", + dump_clean_first=True, + ) + torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5) - expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors) - gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors) - torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5) + # Checking there is only two generated graphs otherwise, it means there are graph breaks. + self.assertEqual(len(onnx_models), 2) + torch.testing.assert_allclose( + expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5 + ) if __name__ == "__main__": diff --git a/onnxscript/tools/transformers_models/phi3_test.py b/onnxscript/tools/transformers_models/phi3_test.py index 62bb6faf8..331f3066e 100644 --- a/onnxscript/tools/transformers_models/phi3_test.py +++ b/onnxscript/tools/transformers_models/phi3_test.py @@ -75,8 +75,6 @@ def test_phi3_dort_static(self): onnxscript.tools.transformers_models.phi3.get_phi3_model() ) input_tensors = input_tensors_many[0] - expected = model(*input_tensors) - local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False) compiled_model = torch.compile( @@ -86,12 +84,23 @@ def test_phi3_dort_static(self): fullgraph=True, ) - results = compiled_model(*input_tensors) - torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5) + expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop( + model, *input_tensors + ) + results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop( + compiled_model, + *input_tensors, + dump_onnx_models=True, + dump_prefix="_dump_dort_phi3", + dump_clean_first=True, + ) + torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5) - expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors) - gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors) - torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5) + # Checking there is only two generated graphs otherwise, it means there are graph breaks. 
+ self.assertEqual(len(onnx_models), 2) + torch.testing.assert_allclose( + expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5 + ) if __name__ == "__main__": diff --git a/onnxscript/tools/transformers_models/phi_test.py b/onnxscript/tools/transformers_models/phi_test.py index f67745a6d..0e7c01cb3 100644 --- a/onnxscript/tools/transformers_models/phi_test.py +++ b/onnxscript/tools/transformers_models/phi_test.py @@ -60,7 +60,6 @@ def test_phi_export_cuda(self): def test_phi_dort_static(self): model, input_tensors_many, _ = onnxscript.tools.transformers_models.phi.get_phi_model() input_tensors = input_tensors_many[0] - expected = model(*input_tensors) local_aot_ort = onnxscript.tools.training_helper.make_aot_ort(dynamic=False) @@ -71,12 +70,23 @@ def test_phi_dort_static(self): fullgraph=True, ) - results = compiled_model(*input_tensors) - torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5) + expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop( + model, *input_tensors + ) + results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop( + compiled_model, + *input_tensors, + dump_onnx_models=True, + dump_prefix="_dump_dort_phi", + dump_clean_first=True, + ) + torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5) - expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors) - gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors) - torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5) + # Checking there is only two generated graphs otherwise, it means there are graph breaks. + self.assertEqual(len(onnx_models), 2) + torch.testing.assert_allclose( + expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5 + ) if __name__ == "__main__": From 67efd9c036426838f722901b0342437b7b13f347 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Wed, 3 Jul 2024 10:37:16 +0200 Subject: [PATCH 3/4] fix merge conflict Signed-off-by: Xavier Dupre --- .../tools/transformers_models/llama_test.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/onnxscript/tools/transformers_models/llama_test.py b/onnxscript/tools/transformers_models/llama_test.py index ace233649..445c46fb0 100644 --- a/onnxscript/tools/transformers_models/llama_test.py +++ b/onnxscript/tools/transformers_models/llama_test.py @@ -129,12 +129,23 @@ def test_llama_dort_static(self): fullgraph=True, ) - results = compiled_model(*input_tensors) - torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5) + expected_results, expected_gradients = onnxscript.tools.training_helper.train_loop( + model, *input_tensors + ) + results, gradients, onnx_models = onnxscript.tools.training_helper.train_loop( + compiled_model, + *input_tensors, + dump_onnx_models=True, + dump_prefix="_dump_dort_llama", + dump_clean_first=True, + ) + torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5) - expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors) - gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors) - torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5) + # Checking there is only two generated graphs otherwise, it means there are graph breaks. 
+        self.assertEqual(len(onnx_models), 2)
+        torch.testing.assert_allclose(
+            expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5
+        )
 
 
 if __name__ == "__main__":

From a42d94e273365f0bb3b6aa701dae2e39165d5a58 Mon Sep 17 00:00:00 2001
From: Xavier Dupre
Date: Wed, 3 Jul 2024 11:27:28 +0200
Subject: [PATCH 4/4] lint

Signed-off-by: Xavier Dupre
---
 onnxscript/function_libs/torch_lib/backward_test.py  | 2 +-
 onnxscript/tools/transformers_models/mistral_test.py | 2 +-
 onnxscript/tools/transformers_models/phi3_test.py    | 2 +-
 onnxscript/tools/transformers_models/phi_test.py     | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/backward_test.py b/onnxscript/function_libs/torch_lib/backward_test.py
index 0611196f6..dda56493a 100644
--- a/onnxscript/function_libs/torch_lib/backward_test.py
+++ b/onnxscript/function_libs/torch_lib/backward_test.py
@@ -1,6 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-# pylint: disable=not-callable
+# pylint: disable=not-callable, unbalanced-tuple-unpacking
 
 import copy
 import os
diff --git a/onnxscript/tools/transformers_models/mistral_test.py b/onnxscript/tools/transformers_models/mistral_test.py
index 27ed5e677..422da323e 100644
--- a/onnxscript/tools/transformers_models/mistral_test.py
+++ b/onnxscript/tools/transformers_models/mistral_test.py
@@ -1,6 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-# pylint: disable=not-callable
+# pylint: disable=not-callable, unbalanced-tuple-unpacking
 
 import copy
 import sys
diff --git a/onnxscript/tools/transformers_models/phi3_test.py b/onnxscript/tools/transformers_models/phi3_test.py
index 072ab6a86..a49908f4a 100644
--- a/onnxscript/tools/transformers_models/phi3_test.py
+++ b/onnxscript/tools/transformers_models/phi3_test.py
@@ -1,6 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-# pylint: disable=not-callable
+# pylint: disable=not-callable, unbalanced-tuple-unpacking
 
 import copy
 import sys
diff --git a/onnxscript/tools/transformers_models/phi_test.py b/onnxscript/tools/transformers_models/phi_test.py
index b6576a99a..99d534ddd 100644
--- a/onnxscript/tools/transformers_models/phi_test.py
+++ b/onnxscript/tools/transformers_models/phi_test.py
@@ -1,6 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-# pylint: disable=not-callable
+# pylint: disable=not-callable, unbalanced-tuple-unpacking
 
 import copy
 import sys
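Across patches 2 to 4 every DORT test converges on the same pattern: run `train_loop` on the eager model and on the compiled model, compare outputs and gradients, and require exactly two dumped graphs (forward and backward) as a graph-break check. A standalone sketch of that pattern is given below; the stand-in model, input shape and dump prefix are illustrative assumptions, not values taken from the patches.

import copy

import onnx
import torch

from onnxscript.tools.training_helper import make_aot_ort, train_loop

model = torch.nn.Linear(14, 10)  # stand-in for the transformer models used in the tests
inputs = (torch.randn(2, 14),)

compiled_model = torch.compile(
    copy.deepcopy(model),
    backend=make_aot_ort(dynamic=False),
    dynamic=False,
    fullgraph=True,
)

expected_results, expected_gradients = train_loop(model, *inputs)
results, gradients, onnx_models = train_loop(
    compiled_model,
    *inputs,
    dump_onnx_models=True,
    dump_prefix="_dump_example_pattern",
    dump_clean_first=True,
)

torch.testing.assert_allclose(expected_results[0], results[0], atol=1e-5, rtol=1e-5)
torch.testing.assert_allclose(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5)

# Two dumped files mean one forward and one backward graph, i.e. no graph breaks.
assert len(onnx_models) == 2, f"unexpected graph breaks: {onnx_models}"
for path in onnx_models:
    onnx.checker.check_model(onnx.load(path))  # the dumped files are regular ONNX models

The per-model prefixes used in the patches (`_dump_testbw_working`, `_dump_dort_llama`, `_dump_dort_mistral`, ...) keep the dumped files separate, and the new `_dump_*` entry in `.gitignore` keeps them out of version control.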