Merge pull request #77 from kikusui6192/Update_saliency.py
Added arg device in get_saliency and get_image_saliency_results
This commit is contained in:
Коммит
c9e5e9c0c4
|
@ -55,8 +55,9 @@ def _get_layer_path(model):
|
|||
else: #TODO: guess layer for other networks?
|
||||
return None
|
||||
|
||||
def get_saliency(model, raw_input, input, label, method='integrate_grad', layer_path=None):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
def get_saliency(model, raw_input, input, label, device=None, method='integrate_grad', layer_path=None):
|
||||
if device == None or type(device) != torch.device:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
model.to(device)
|
||||
input = input.to(device)
|
||||
|
@ -71,7 +72,6 @@ def get_saliency(model, raw_input, input, label, method='integrate_grad', layer_
|
|||
model.zero_grad()
|
||||
|
||||
layer_path = layer_path or _get_layer_path(model)
|
||||
|
||||
exp = _get_explainer(method, model, layer_path)
|
||||
saliency = exp.explain(input, label, raw_input)
|
||||
|
||||
|
@ -85,12 +85,13 @@ def get_saliency(model, raw_input, input, label, method='integrate_grad', layer_
|
|||
return None
|
||||
|
||||
def get_image_saliency_results(model, raw_image, input, label,
|
||||
device=None,
|
||||
methods=['lime_imagenet', 'gradcam', 'smooth_grad',
|
||||
'guided_backprop', 'deeplift', 'grad_x_input'],
|
||||
layer_path=None):
|
||||
results = []
|
||||
for method in methods:
|
||||
sal = get_saliency(model, raw_image, input, label, method=method)
|
||||
sal = get_saliency(model, raw_image, input, label, device=device, method=method)
|
||||
|
||||
if sal is not None:
|
||||
results.append(ImageSaliencyResult(raw_image, sal, method))
|
||||
|
@ -116,4 +117,4 @@ def get_image_saliency_plot(image_saliency_results, cols:int=2, figsize:tuple=No
|
|||
|
||||
image_utils.show_image(r.raw_image, img2=saliency_upsampled,
|
||||
alpha2=r.saliency_alpha, cmap2=r.saliency_cmap, ax=ax)
|
||||
return figure
|
||||
return figure
|
Загрузка…
Ссылка в новой задаче