v0.3.1

@supriyar released this 26 Jun 20:36

Highlights

We are excited to announce the 0.3 release of torchao! This release adds a new quantize API, the MX format, an FP6 dtype and bitpacking, accelerated training with 2:4 sparsity, and benchmarking infrastructure for llama2/llama3 models.

quantize API (#256)

We added a tensor-subclass-based quantization API (see the docs and README for usage details). It is planned to replace all existing quantization APIs in torchao for torch 2.4 and later.

Accelerated training with 2:4 sparsity (#184)

You can now accelerate training with 2:4 sparsity, using the runtime pruning and compression kernels written by xFormers. These kernels prune each 4x4 sub-tile so that it is 2:4 sparse in both directions (row-wise and column-wise), which lets them handle both the forward and the backward pass during training. We see a 1.3x speedup for the MLP layers of ViT-L across a forward and backward pass.
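To make "2:4 sparse in both directions" concrete, the brute-force sketch below prunes a single 4x4 tile so that every row and every column keeps exactly two of its four values; the same mask is then 2:4 sparse for both the tile and its transpose, which is what a backward pass that multiplies by the transposed weight would need. This is an illustrative reference in plain Python, not the xFormers kernel, and the helper names are hypothetical. It simply tries all 90 valid masks and keeps the one preserving the most magnitude.

```python
from itertools import combinations, product

# All ways to keep 2 of the 4 positions in one row of the tile.
ROW_CHOICES = list(combinations(range(4), 2))

# Enumerate every 4x4 mask with exactly two kept entries per row AND per column
# (90 masks in total), i.e. masks that are 2:4 sparse in both directions.
BIDIRECTIONAL_2_4_MASKS = []
for rows in product(ROW_CHOICES, repeat=4):
    col_counts = [0, 0, 0, 0]
    for kept in rows:
        for c in kept:
            col_counts[c] += 1
    if all(n == 2 for n in col_counts):
        BIDIRECTIONAL_2_4_MASKS.append(rows)


def prune_tile_2_4_both_ways(tile):
    """Zero half of a 4x4 tile so each row and each column keeps exactly 2 values.

    Chooses, by brute force, the mask that preserves the largest total magnitude.
    """
    def kept_magnitude(mask):
        return sum(abs(tile[r][c]) for r, kept in enumerate(mask) for c in kept)

    best = max(BIDIRECTIONAL_2_4_MASKS, key=kept_magnitude)
    return [
        [tile[r][c] if c in best[r] else 0.0 for c in range(4)]
        for r in range(4)
    ]
```

A real kernel cannot afford this search per tile, which is why fast heuristics are used in practice; the sketch only pins down the constraint being satisfied.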

MX support (#264)

We added prototype support for the MX formats for training and inference, with a reference native PyTorch implementation of the training and inference primitives needed to use MX-accelerated matrix multiplications. The MX numerical formats are new low-precision formats recently accepted into the OCP Microscaling (MX) specification:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
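The core MX idea is block scaling: a small block of elements shares one power-of-two scale, and each element is stored in a narrow type. The plain-Python sketch below shows a simplified MXINT8-style scheme, assuming 32-element blocks, a shared exponent taken from the block's largest magnitude, and int8 element codes with an implicit 2**-6 factor. The helper names are hypothetical, the rounding and clamping choices are simplifying assumptions rather than the spec's exact rules, and the MX floating-point element types are not modeled.

```python
import math

BLOCK_SIZE = 32  # assumed: elements sharing one scale in an MX block


def mx_int8_quantize(values):
    """Quantize floats into a list of (shared_exponent, int8 codes) blocks.

    Each block shares the scale 2**shared_exp; each code carries an implicit
    factor of 2**-6, so code 64 represents exactly 1.0 * scale.
    """
    blocks = []
    for start in range(0, len(values), BLOCK_SIZE):
        block = values[start:start + BLOCK_SIZE]
        amax = max(abs(v) for v in block)
        if amax == 0.0:
            blocks.append((0, [0] * len(block)))
            continue
        # frexp gives amax = m * 2**e with 0.5 <= m < 1, so floor(log2(amax)) == e - 1.
        _, e = math.frexp(amax)
        shared_exp = e - 1
        scale = 2.0 ** shared_exp
        # Round to the nearest code and clamp to a symmetric int8 range.
        codes = [max(-127, min(127, round(v / scale * 64))) for v in block]
        blocks.append((shared_exp, codes))
    return blocks


def mx_int8_dequantize(blocks):
    """Inverse of mx_int8_quantize (up to rounding error)."""
    out = []
    for shared_exp, codes in blocks:
        scale = 2.0 ** shared_exp
        out.extend(c / 64 * scale for c in codes)
    return out
```

For example, `mx_int8_quantize([1.0, 0.5, -0.25])` yields a single block with shared exponent 0 and codes `[64, 32, -16]`, which dequantize back exactly. Storing one exponent per 32 elements keeps the scale overhead small while letting each block adapt to its own dynamic range.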

Benchmarking (#276, #374)

We added a stable way to benchmark llama2 and llama3 models that includes perf/accuracy comparisons. See torchao/_models/llama/benchmarks.sh for more details.

🌟 💥 Community Contributions 🌟 💥

FP6 support (#279, #283, #358)

@gau-nernst added support for the FP6 dtype and a mixed FP16 x FP6 matmul kernel that works with torch.compile. Benchmarks show a 2.3x speedup over the BF16 baseline for meta-llama/Llama-2-7b-chat-hf.
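To show what an FP6 value looks like, the sketch below decodes all 64 six-bit codes and rounds a float to the nearest one, assuming the E3M2 layout used by FP6-LLM (1 sign bit, 3 exponent bits with bias 3, 2 mantissa bits, no inf/NaN encodings). Under that assumption the largest magnitude is 28.0. The helper names are hypothetical, and the nearest-value search is for clarity only; a real kernel uses bit manipulation.

```python
# Assumed FP6 E3M2 layout: 1 sign bit, 3 exponent bits (bias 3), 2 mantissa bits,
# with no encodings reserved for inf/NaN.
EXP_BIAS = 3


def fp6_e3m2_decode(code):
    """Return the float value of a 6-bit FP6 (E3M2) code."""
    sign = -1.0 if (code >> 5) & 1 else 1.0
    exp = (code >> 2) & 0b111
    mant = code & 0b11
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * 2.0 ** (1 - EXP_BIAS) * (mant / 4)
    return sign * 2.0 ** (exp - EXP_BIAS) * (1 + mant / 4)


# Every representable FP6 value, indexed by its 6-bit code.
FP6_VALUES = [fp6_e3m2_decode(c) for c in range(64)]


def fp6_e3m2_encode(x):
    """Return the code whose value is nearest to x, saturating at +/-28.0.

    Ties go to the lower code (min keeps the first best match), which is
    simpler than round-to-nearest-even.
    """
    return min(range(64), key=lambda c: abs(FP6_VALUES[c] - x))
```

With only 3 exponent bits the range is narrow, so in practice weights are scaled before being cast to FP6.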

Bitpacking (#307, #282)

@vayuda, @melvinebenezer, @CoffeeVampir3 and @andreaskoepf added support for packing/unpacking lower-bit dtypes, using torch.compile to generate the kernels, and added UInt2 and Bitnet tensors built on this approach.
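As a minimal illustration of what packing a lower-bit dtype means, the sketch below stores four 2-bit unsigned values in each byte using plain Python. The low-bits-first layout and the helper names are assumptions for illustration only; torchao generates its packing kernels with torch.compile and may order the bits differently.

```python
def pack_uint2(values):
    """Pack 2-bit unsigned ints (0..3) four per byte, first value in the low bits."""
    if any(not 0 <= v <= 3 for v in values):
        raise ValueError("uint2 values must be in [0, 3]")
    packed = bytearray()
    for start in range(0, len(values), 4):
        byte = 0
        for i, v in enumerate(values[start:start + 4]):
            byte |= v << (2 * i)  # slot i occupies bits 2*i and 2*i + 1
        packed.append(byte)
    return bytes(packed)


def unpack_uint2(packed, count):
    """Inverse of pack_uint2; count is needed because the last byte may be partial."""
    values = []
    for byte in packed:
        for i in range(4):
            values.append((byte >> (2 * i)) & 0b11)
    return values[:count]
```

For example, `[0, 1, 2, 3]` packs into the single byte `0b11100100`, a 4x saving over storing each value in its own uint8.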

FP8 split-gemm kernel (#263)

Added the FP8 split-gemm Triton kernel written by @AdnanHoque to torchao; it shows speedups over the cuBLAS kernel for batch sizes <= 16.
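Assuming "split-gemm" here refers to split-K partitioning of the reduction dimension, the idea can be sketched in plain Python: the K dimension is cut into chunks, each chunk produces a partial product, and the partials are summed into the output. On a GPU each chunk would run on its own thread block, which adds parallelism when the batch (M) dimension is too small to fill the device; that is one plausible reason gains appear at small batch sizes. This is not the Triton kernel from #263, and the function name is hypothetical.

```python
def matmul_split_k(a, b, split_k):
    """Compute a @ b (lists of lists) by splitting the shared K dimension into split_k chunks.

    Each K chunk yields a partial product (separate thread blocks in a GPU kernel);
    the partials are accumulated into the output.
    """
    m, k, n = len(a), len(b), len(b[0])
    chunk = -(-k // split_k)  # ceiling division
    out = [[0.0] * n for _ in range(m)]
    for k0 in range(0, k, chunk):
        k1 = min(k0 + chunk, k)
        for i in range(m):
            for j in range(n):
                # Partial dot product over this K slice, then accumulate.
                out[i][j] += sum(a[i][kk] * b[kk][j] for kk in range(k0, k1))
    return out
```

The result is the same as an ordinary matmul; only the order of the reduction changes, which on real hardware trades an extra reduction step for better occupancy.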

BC Breaking

Deprecations

  • Deprecate top level quantization APIs #344

1. int8 weight only quantization

apply_weight_only_int8_quant(model) or change_linear_weights_to_int8_woqtensors(model)

-->

# for torch 2.4+
from torchao.quantization import quantize, int8_weight_only
quantize(model, int8_weight_only())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
change_linear_weights_to_int8_woqtensors(model)

2. int8 dynamic quantization

apply_dynamic_quant(model) or change_linear_weights_to_int8_dqtensors(model)

-->

import torch

# Fuse the int8*int8 -> int32 matmul and the subsequent mul op, avoiding materialization of the int32 intermediate tensor
torch._inductor.config.force_fuse_int_mm_with_mul = True

# for torch 2.4+
from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
quantize(model, int8_dynamic_activation_int8_weight())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
change_linear_weights_to_int8_dqtensors(model)

3. int4 weight only quantization

change_linear_weights_to_int4_woqtensors(model)

-->

# for torch 2.4+
from torchao.quantization import quantize, int4_weight_only
quantize(model, int4_weight_only())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
change_linear_weights_to_int4_woqtensors(model)
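The three APIs above correspond to different quantization schemes. As a rough, framework-free illustration (plain Python, hypothetical helper names, not torchao's implementation), the sketch below shows symmetric per-channel int8 weight-only quantization; int8 dynamic quantization, where activations are quantized per row (per token) at runtime and the integer matmul result is rescaled by a multiply; and group-wise 4-bit weight quantization. The asymmetric uint4 encoding with a per-group zero point is an assumption; torchao's int4 path uses its own packed layout and conventions.

```python
def quantize_int8_per_channel(weight):
    """Symmetric int8 weight-only quantization with one scale per output row."""
    qweight, scales = [], []
    for row in weight:
        amax = max(abs(w) for w in row)
        scale = amax / 127 if amax > 0 else 1.0
        qweight.append([round(w / scale) for w in row])
        scales.append(scale)
    return qweight, scales


def int8_dynamic_linear(x, qweight, w_scales):
    """Approximate x @ W.T: quantize each row of x at runtime, int matmul, then rescale."""
    out = []
    for row in x:
        amax = max(abs(v) for v in row)
        x_scale = amax / 127 if amax > 0 else 1.0
        qx = [round(v / x_scale) for v in row]
        # Integer accumulation (int32 on hardware) followed by a float mul by both
        # scales: the matmul + mul pattern that force_fuse_int_mm_with_mul targets.
        out.append([
            sum(a * b for a, b in zip(qx, qrow)) * x_scale * w_scale
            for qrow, w_scale in zip(qweight, w_scales)
        ])
    return out


def quantize_int4_groupwise(row, group_size):
    """Asymmetric 4-bit (0..15) quantization of one row, one (scale, zero_point) per group."""
    groups = []
    for start in range(0, len(row), group_size):
        g = row[start:start + group_size]
        lo, hi = min(g), max(g)
        scale = (hi - lo) / 15 if hi > lo else 1.0
        zero_point = round(-lo / scale)
        q = [max(0, min(15, round(v / scale) + zero_point)) for v in g]
        groups.append((q, scale, zero_point))
    return groups


def dequantize_int4_groupwise(groups):
    """Map each 4-bit code back to a float using its group's scale and zero point."""
    return [(q - zp) * scale for qs, scale, zp in groups for q in qs]
```

Weight-only schemes shrink the weights and dequantize them inside the matmul, while dynamic quantization also quantizes activations so the matmul itself can run in int8; smaller groups in the 4-bit scheme track local ranges more closely at the cost of storing more scales.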

New Features

  • Add quantize #256
  • Add a prototype of MX format training and inference #264
  • [FP6-LLM] Port splitK map from DeepSpeed #283
  • Improve FP6-LLM 2+4bit weight splitting + user API #279
  • Bitpacking #291
  • training acceleration via runtime semi-structured sparsity #184
  • Bitpackingv2 #307
  • Add FP6-LLM doc and move FP6-LLM to prototype #358
  • Added first bits of Uint2Tensor and BitnetTensor #282

Improvements

  • Improve primitives for FP6 quant #248
  • Extract eval code from GPTQ for more general usage #275
  • Factor out the specific configurations to helper functions #286
  • Add support for AQTLayout, PlainAQTLayout and TensorCoreTiledAQTLayout #278
  • Graceful handling of cpp extensions #296
  • Refactor int8 dynamic quantization with call to quantize #294
  • [NF4][FSDP] return contiguous quantization_factor #298
  • Refactor int4 and int8 weight only quantization to use quantize #301
  • Adding a quick way for users to test model eval for hf models #328
  • Wrap torch.ops.quantized_decomposed to improve import errors #310
  • [NF4Tensor] Switch to save for backward since are now a tensor input #323
  • Refactor rest of tinygemm quant primitive ops #321
  • Move some util functions from quantization.utils to torchao.utils #337
  • Clean up FP6-LLM #304
  • Move quant ops to utils.py #331
  • FP6-LLM clean up (again) #339
  • Improving hf_eval.py #342
  • Generalize Model Size Code #364
  • Minor upgrades to bit pack #347
  • Factor out dispatch and layout registration table #360
  • Add register_apply_tensor_subclass #366
  • Refactor custom FPx cast #363
  • Remove all dependencies except torch #369
  • Enable a test for loading state_dict with tensor subclasses #389
  • 073 scripts for benchmarks #372
  • Add WOQ int8 test with Inductor Freeze #362
  • Benchmarking updates for semi-structured sparse training #398
  • add FSDP QLoRA test and revert failing PR #403
  • Refactor the API for quant method argument for quantize function #400
  • eval script fixes #414

Bug Fixes

  • Fixed the HQQ import skip #262
  • fixing autoquant bug #265
  • Fix eval import after #275 #290
  • Fixed f-string printing of NF4Tensors #297
  • Check and fix dequantize_affine is idempotent #309
  • Update old pretrained TorchVision API in ao tutorials (#313) #314
  • Fix dimension issues for int4 weight only quant path #330
  • Fix compile in hf_eval.py #341
  • task_list to tasks in hf_eval #343
  • fixing peak memory stats for benchmark #353
  • Fix inductor config BC change #382
  • fixing scripts #395

Performance

  • FP8 splitgemm user defined triton kernel #263
  • sparse benchmarking numbers #303
  • Fix FP6-LLM benchmark #312
  • Adding Llama to TorchAO #276
  • eval script for llama #374
  • 077 autoquant gpt fast #361

Docs

  • add static folder for images + fix links #271
  • Fix Readme and remove unused kernel #270
  • Kernel docs #274
  • Quantization Docstrings #273
  • Add AffineQuantizedTensor based workflow doc and examples #277
  • Add AUTOQUANT_CACHE docs for reusing the same quantization plan #329
  • Update nightly build instructions #334
  • add link to benchmarking script #355
  • New README #392
  • Minor README updates #401
  • Add quantize to doc page #367
  • Add link to new custom op tutorial #424

Devs

  • ci: Add push trigger for binary build workflows #259
  • Make fp8 test explicit #266
  • Move AffineQuantizedTensor to torchao/dtypes #272
  • Add suffix to package version #293
  • Re-enable AOTI tests #212
  • Add fused QKV HQQ triton_mm test #306
  • Pin CUDA nightly to mitigate regression #322
  • Unpin CUDA nightly #333
  • Add architecture to index postfix for nightly builds #336
  • Update regression test to python 3.8 #340
  • Remove test_ops.py warning spew #267
  • Add torchao.version #359
  • make torchao test discovery pass in fbcode #351
  • use pytorch version env variable #373
  • Update pre_build_script.sh #390
  • Add support for building CUDA extension on Windows #396
  • Add trymerge #388
  • Fix github CI error #409
  • Fix missing dependencies in trymerge workflow #413
  • Setup trymerge secrets #416
  • Pin CUDA nightlies for mx failures #428
  • fix mx triton kernel after PyTorch triton pin change #431

Untopiced

  • Print the code when the check failed #254
  • Retry of D58015187 Move AsyncCompile to a different file by @jamesjwu in #302
  • Revert "Clean up FP6-LLM" #338
  • Update version to 0.3.0 #348

New Contributors

Full Changelog: v0.2.0...v0.3.0-rc1

We closed about 60% of the tasks planned for 0.3.0; the rest will spill over into upcoming releases. We will post a task list for 0.4.0 next, which we aim to release at the end of July 2024. We plan to follow a monthly release cadence until further notice.

EDIT: We made the 0.3.1 patch release to include two more PRs, #449 and #455, so torchao now has no runtime dependencies.