PyTorch with DirectML Samples

PyTorch with DirectML enables training and inference of complex machine learning models on a wide range of DirectX 12-compatible hardware. This support is provided by torch-directml, a plugin for PyTorch.
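As a rough illustration of what the plugin enables, the sketch below runs one training step of a tiny model on the DirectML device. It assumes `torch_directml.device()` returns a `torch.device` for the default DirectML adapter, and that the few operators used here are supported by the DirectML backend. The helper names `pick_device` and `train_step` are hypothetical, and the CPU fallback is an addition for illustration, not part of torch-directml.

```python
import torch
from torch import nn


def pick_device() -> torch.device:
    """Hypothetical helper: prefer the DirectML device, fall back to CPU if torch-directml is absent."""
    try:
        import torch_directml  # installed with `pip install torch-directml`
    except ImportError:
        return torch.device("cpu")  # assumption for illustration: run on CPU without the plugin
    return torch_directml.device()  # default DirectML adapter


def train_step(device: torch.device) -> float:
    """Hypothetical helper: one SGD step of a tiny linear model on `device`; returns the pre-update loss."""
    torch.manual_seed(0)
    model = nn.Linear(4, 1).to(device)  # move the parameters to the chosen device
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(8, 4).to(device)  # inputs are created on the CPU, then moved
    y = torch.randn(8, 1).to(device)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()  # gradients are computed on the same device as the model
    optimizer.step()
    return loss.item()  # .item() copies the scalar back to the host


if __name__ == "__main__":
    dev = pick_device()
    print(f"device: {dev}, loss: {train_step(dev):.4f}")
```

Inference follows the same pattern: move the model and inputs to the device with `.to(...)`, then call the model (typically under `torch.no_grad()`).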

DirectML is a high-performance, hardware-accelerated DirectX 12 library for machine learning. DirectML provides GPU acceleration for common machine learning tasks across a broad range of supported hardware and drivers, including all DirectX 12-capable GPUs from vendors such as AMD, Intel, NVIDIA, and Qualcomm.
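Because DirectML can target any DirectX 12-capable GPU, a machine may expose more than one adapter (for example, an integrated and a discrete GPU). The sketch below lists them. It assumes the `torch_directml.device_count()` and `torch_directml.device_name(index)` helpers available in recent torch-directml releases; check the package documentation for your version. `list_directml_adapters` is a hypothetical helper name.

```python
from typing import List  # typing.List keeps this compatible with Python 3.8


def list_directml_adapters() -> List[str]:
    """Hypothetical helper: names of the DirectML adapters visible to torch-directml (empty if not installed)."""
    try:
        import torch_directml
    except ImportError:
        return []  # plugin not installed, so there are no DirectML adapters to report
    # Assumed API: device_count() and device_name(index) from torch-directml.
    return [torch_directml.device_name(i) for i in range(torch_directml.device_count())]


if __name__ == "__main__":
    for index, name in enumerate(list_directml_adapters()):
        # torch_directml.device(index) would select this adapter explicitly.
        print(index, name)
```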

More information about DirectML can be found on the DirectML Overview page on Microsoft Learn.

Setup

PyTorch with DirectML is supported on the latest versions of Windows and on the Windows Subsystem for Linux (WSL 2), and is available as a package on PyPI. For more information about getting started with torch-directml, see our Windows or WSL 2 guidance on Microsoft Learn.

Once a Python (3.8 to 3.12) environment is set up, install the latest release of torch-directml by running the following command:

pip install torch-directml
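To confirm the installation worked, a quick smoke test in the spirit of the Microsoft Learn setup guidance is to add two small tensors on the DirectML device. The sketch below also checks the interpreter against the Python 3.8 to 3.12 range stated above. The function names are hypothetical, and the code assumes `torch_directml.device()` returns the default DirectML device.

```python
import sys


def check_python_version() -> bool:
    """Hypothetical helper: True if the interpreter is within the 3.8-3.12 range listed in this README."""
    return (3, 8) <= sys.version_info[:2] <= (3, 12)


def smoke_test() -> int:
    """Hypothetical helper: add two one-element tensors on the DirectML device and return the sum."""
    import torch
    import torch_directml  # raises ImportError if torch-directml is not installed

    dml = torch_directml.device()
    a = torch.tensor([1]).to(dml)
    b = torch.tensor([2]).to(dml)
    return (a + b).item()  # copies the result back to the host as a Python int


if __name__ == "__main__":
    print("supported Python version:", check_python_version())
    print("1 + 2 on DirectML =", smoke_test())
```

If `smoke_test()` raises `ImportError`, the package is not installed in the active environment.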

Samples

Try the torch-directml samples below, or explore the cv, transformer, llm, and diffusion folders: