Merge pull request #77 from kikusui6192/Update_saliency.py

Added arg device in get_saliency and get_image_saliency_results
This commit is contained in:
Shital Shah 2023-08-29 23:59:12 -07:00 коммит произвёл GitHub
Родитель 2031c9222d 654b7eb7ee
Коммит c9e5e9c0c4
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 6 добавлений и 5 удалений

Просмотреть файл

@ -55,8 +55,9 @@ def _get_layer_path(model):
else: #TODO: guess layer for other networks?
return None
def get_saliency(model, raw_input, input, label, method='integrate_grad', layer_path=None):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_saliency(model, raw_input, input, label, device=None, method='integrate_grad', layer_path=None):
if device == None or type(device) != torch.device:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
input = input.to(device)
@ -71,7 +72,6 @@ def get_saliency(model, raw_input, input, label, method='integrate_grad', layer_
model.zero_grad()
layer_path = layer_path or _get_layer_path(model)
exp = _get_explainer(method, model, layer_path)
saliency = exp.explain(input, label, raw_input)
@ -85,12 +85,13 @@ def get_saliency(model, raw_input, input, label, method='integrate_grad', layer_
return None
def get_image_saliency_results(model, raw_image, input, label,
device=None,
methods=['lime_imagenet', 'gradcam', 'smooth_grad',
'guided_backprop', 'deeplift', 'grad_x_input'],
layer_path=None):
results = []
for method in methods:
sal = get_saliency(model, raw_image, input, label, method=method)
sal = get_saliency(model, raw_image, input, label, device=device, method=method)
if sal is not None:
results.append(ImageSaliencyResult(raw_image, sal, method))
@ -116,4 +117,4 @@ def get_image_saliency_plot(image_saliency_results, cols:int=2, figsize:tuple=No
image_utils.show_image(r.raw_image, img2=saliency_upsampled,
alpha2=r.saliency_alpha, cmap2=r.saliency_cmap, ax=ax)
return figure
return figure