# (File-viewer residue from the original paste: 68 lines, 2.0 KiB, Python.)
# -*- coding: utf-8 -*-
|
|
# File: gpu.py
|
|
|
|
|
|
import os
|
|
|
|
from . import logger
|
|
from .concurrency import subproc_call
|
|
from .nvml import NVMLContext
|
|
from .utils import change_env
|
|
|
|
# Public names exported by ``from <module> import *``; get_nr_gpu is an alias of get_num_gpu.
__all__ = ['change_gpu', 'get_nr_gpu', 'get_num_gpu']
|
|
|
|
|
|
def change_gpu(val):
    """
    Build a context that restricts ``CUDA_VISIBLE_DEVICES`` to one GPU.

    Args:
        val: an integer, the index of the GPU or -1 to disable GPU.

    Returns:
        a context where ``CUDA_VISIBLE_DEVICES=val``. For -1 (or '-1') the
        variable is given an empty string instead of the literal "-1".
    """
    device = str(val)
    # -1 maps to an empty device list rather than being passed through verbatim.
    return change_env('CUDA_VISIBLE_DEVICES', '' if device == '-1' else device)
|
|
|
|
|
|
def _count_visible_devices(env):
    """
    Count the devices listed in a ``CUDA_VISIBLE_DEVICES`` value.

    CUDA only exposes the devices listed before the first invalid entry, so
    counting stops at the first negative index. An empty entry stops counting
    too, which makes an empty value (what ``change_gpu(-1)`` passes) mean
    "no GPU", as does ``"-1"``. Whitespace around entries is ignored.

    Args:
        env (str): value of the ``CUDA_VISIBLE_DEVICES`` environment variable.

    Returns:
        int: number of entries before the first empty or negative one.
    """
    count = 0
    for entry in env.split(','):
        entry = entry.strip()
        if not entry or entry.startswith('-'):
            break
        count += 1
    return count


def _count_smi_gpus(output):
    """
    Count the GPUs reported in the decoded output of ``nvidia-smi -L``.

    Only lines beginning with ``"GPU "`` are counted. Blank lines (including an
    entirely empty output) and other lines, such as indented MIG sub-device
    entries, are ignored.

    Args:
        output (str): decoded stdout of ``nvidia-smi -L``.

    Returns:
        int: number of ``GPU ...`` lines.
    """
    return sum(1 for line in output.splitlines() if line.startswith('GPU '))


def get_num_gpu():
    """
    Returns:
        int: #available GPUs in CUDA_VISIBLE_DEVICES, or in the system.

    Sources are tried in this order:

    1. ``CUDA_VISIBLE_DEVICES`` whenever it is set, even to an empty string.
       Both ``""`` and ``"-1"`` hide every GPU and therefore count as 0.
    2. ``nvidia-smi -L``.
    3. NVML.
    4. TensorFlow's ``device_lib``. This initializes all GPUs as a side effect.
    """

    def warn_return(ret, message):
        # Warn if GPUs were found but the installed TensorFlow cannot use them.
        # Without TensorFlow there is nothing to check, so ret is returned as-is.
        try:
            import tensorflow as tf
        except ImportError:
            return ret

        built_with_cuda = tf.test.is_built_with_cuda()
        if not built_with_cuda and ret > 0:
            logger.warn(message + "But TensorFlow was not built with CUDA support and could not use GPUs!")
        return ret

    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    # Test `is not None` rather than truthiness: an empty value is an explicit
    # request to hide all GPUs (change_gpu(-1) passes ''), so it must not fall
    # through to probing the whole system.
    if env is not None:
        return warn_return(_count_visible_devices(env), "Found non-empty CUDA_VISIBLE_DEVICES. ")

    output, code = subproc_call("nvidia-smi -L", timeout=5)
    if code == 0:
        output = output.decode('utf-8')
        return warn_return(_count_smi_gpus(output), "Found nvidia-smi. ")

    try:
        # Use NVML to query device properties
        with NVMLContext() as ctx:
            num_devices = ctx.num_devices()
    except Exception:
        # Best-effort: NVML may be missing or fail to initialize; fall back to TF.
        num_devices = None
    # warn_return runs outside the try/with so its own errors are not masked
    # as "NVML unavailable", and NVML is already shut down when TF is imported.
    if num_devices is not None:
        return warn_return(num_devices, "NVML found nvidia devices. ")

    # Fallback
    # Note this will initialize all GPUs and therefore has side effect
    # https://github.com/tensorflow/tensorflow/issues/8136
    logger.info("Loading local devices by TensorFlow ...")
    from tensorflow.python.client import device_lib
    local_device_protos = device_lib.list_local_devices()
    return sum(1 for x in local_device_protos if x.device_type == 'GPU')
|
|
|
|
|
|
# Alias of get_num_gpu under an older name (presumably kept for backward compatibility).
get_nr_gpu = get_num_gpu
|