scripts/check-torchcuda.py: add checking for tensor.is_cuda (#3843)

.cpp files are excluded

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Guo Yejun 2023-07-01 06:06:36 +08:00 committed by GitHub
Parent 2ded2ff0be
Commit 807d1b5dfc
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 3 deletions
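For context, the check added by this commit enforces DeepSpeed's accelerator-agnostic device test in place of the CUDA-only tensor.is_cuda attribute. A minimal before/after sketch, assuming the standard deepspeed.accelerator API; the tensor below is illustrative, not taken from the patch:

import torch
from deepspeed.accelerator import get_accelerator

# Illustrative tensor; device_name() resolves to "cuda", "xpu", "cpu", etc.
# depending on which DeepSpeed accelerator backend is active.
x = torch.ones(4, device=get_accelerator().device_name())

# CUDA-only check that the new hook now flags:
#     assert x.is_cuda
# Portable replacement used throughout this commit:
assert get_accelerator().on_accelerator(x)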

Changed file: Triton gelu kernel

@@ -6,6 +6,7 @@
 import torch
 import triton
 import triton.language as tl
+from deepspeed.accelerator import get_accelerator
 @triton.jit
@@ -28,7 +29,7 @@ def gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
 def gelu(activations: torch.Tensor) -> torch.Tensor:
     assert activations.is_contiguous()
-    assert activations.is_cuda
+    assert get_accelerator().on_accelerator(activations)
     output = torch.empty_like(activations)
     n_elements = output.numel()

Changed file: Triton residual_add_bias kernel

@@ -6,6 +6,7 @@
 import torch
 import triton
 import triton.language as tl
+from deepspeed.accelerator import get_accelerator
 @triton.jit
@@ -57,8 +58,11 @@ def residual_add_bias(hidden_state: torch.Tensor, residual: torch.Tensor, attn_o
                       attn_bias: torch.Tensor, final_bias: torch.Tensor, mp_size: int, mlp_after_attn: bool,
                       add_attn_bias: bool, pre_attn_norm: bool):
     # check that all tensors are on the same device
-    assert hidden_state.is_cuda and residual.is_cuda and attn_output.is_cuda \
-        and attn_bias.is_cuda and final_bias.is_cuda
+    assert get_accelerator().on_accelerator(hidden_state) \
+        and get_accelerator().on_accelerator(residual) \
+        and get_accelerator().on_accelerator(attn_output) \
+        and get_accelerator().on_accelerator(attn_bias) \
+        and get_accelerator().on_accelerator(final_bias)
     # check that all tensors have the same dtype
     assert hidden_state.dtype == residual.dtype == attn_output.dtype \
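The five chained on_accelerator() calls added above could equally be collapsed with all(); a sketch of that alternative formulation with a hypothetical helper, not what the patch actually does:

from deepspeed.accelerator import get_accelerator

def all_on_accelerator(*tensors):
    # Hypothetical helper: True only if every tensor lives on the active
    # DeepSpeed accelerator (equivalent to the chained asserts in the hunk).
    return all(get_accelerator().on_accelerator(t) for t in tensors)

# e.g. assert all_on_accelerator(hidden_state, residual, attn_output, attn_bias, final_bias)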

Changed file: scripts/check-torchcuda.py

@@ -52,3 +52,26 @@ elif res.returncode == 2:
     err(f"Error invoking grep on {', '.join(sys.argv[1:])}:")
     err(res.stderr.decode("utf-8"))
     sys.exit(2)
+files = []
+for file in sys.argv[1:]:
+    if not file.endswith(".cpp"):
+        files.append(file)
+res = subprocess.run(
+    ["git", "grep", "-Hn", "--no-index", r"\.is_cuda", *files],
+    capture_output=True,
+)
+if res.returncode == 0:
+    err('''
+Error: The string ".is_cuda" was found. This implies checking if a tensor is a cuda tensor.
+Please replace all calls to "tensor.is_cuda" with "get_accelerator().on_accelerator(tensor)",
+and add the following import line:
+'from deepspeed.accelerator import get_accelerator'
+''')
+    err(res.stdout.decode("utf-8"))
+    sys.exit(1)
+elif res.returncode == 2:
+    err(f"Error invoking grep on {', '.join(files)}:")
+    err(res.stderr.decode("utf-8"))
+    sys.exit(2)
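The extended script keeps the existing convention: it takes candidate files on the command line, drops .cpp files for this particular grep, exits 1 when ".is_cuda" is found, and exits 2 if grep itself fails. A hypothetical manual invocation over tracked Python sources (in the repository it would normally be driven by a pre-commit hook rather than run by hand):

# Hypothetical manual run; the script itself skips .cpp files for the
# ".is_cuda" check and prints the offending locations found by git grep.
python scripts/check-torchcuda.py $(git ls-files '*.py')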