Docs for saliency, Credits section in Readme, gradcam ignored for unknown model types
This commit is contained in:
Родитель
73f42f0a81
Коммит
c99f55c38e
|
@ -117,6 +117,10 @@ We would love your contributions, feedback, questions, and feature requests! Ple
|
|||
|
||||
Join the TensorWatch group on [Facebook](https://www.facebook.com/groups/378075159472803/) to stay up to date or ask any questions.
|
||||
|
||||
## Credits
|
||||
|
||||
TensorWatch utilizes several open source libraries for many of its features. These includes: [hiddenlayer](https://github.com/waleedka/hiddenlayer), [torchstat](https://github.com/Swall0w/torchstat), [Visual-Attribution](https://github.com/yulongwang12/visual-attribution), [pyzmq](https://github.com/zeromq/pyzmq), [receptivefield](https://github.com/fornaxai/receptivefield), [nbformat](https://github.com/jupyter/nbformat). Please see `install_requires` section in [setup.py](setup.py) for upto date list.
|
||||
|
||||
## License
|
||||
|
||||
This project is released under the MIT License. Please review the [License file](LICENSE.txt) for more details.
|
||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
2
setup.py
2
setup.py
|
@ -23,6 +23,6 @@ setuptools.setup(
|
|||
"Operating System :: OS Independent",
|
||||
),
|
||||
install_requires=[
|
||||
'matplotlib', 'numpy', 'pyzmq', 'plotly', 'torchstat', 'receptivefield', 'ipywidgets', 'sklearn'
|
||||
'matplotlib', 'numpy', 'pyzmq', 'plotly', 'torchstat', 'receptivefield', 'ipywidgets', 'sklearn', 'nbformat'
|
||||
]
|
||||
)
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
from torchvision import transforms
|
||||
from . import pytorch_utils
|
||||
import json
|
||||
import json, os
|
||||
|
||||
def get_image_transform():
|
||||
transf = transforms.Compose([ #TODO: cache these transforms?
|
||||
|
@ -21,6 +21,9 @@ def get_normalize_transform():
|
|||
return transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
|
||||
def image2batch(image):
|
||||
return pytorch_utils.image2batch(image, image_transform=get_image_transform())
|
||||
|
||||
def predict(model, images, image_transform=None, device=None):
|
||||
logits = pytorch_utils.batch_predict(model, images,
|
||||
input_transform=image_transform or get_image_transform(), device=device)
|
||||
|
@ -39,15 +42,18 @@ def probabilities2classes(probs, topk=5):
|
|||
top_probs = probs.topk(topk)
|
||||
# return (probability, class_id, class_label, class_code)
|
||||
return tuple((p,c, labels.index2label_text(c), labels.index2label_code(c)) \
|
||||
for p, c in zip(top_probs[0][0].detach().numpy(), top_probs[1][0].detach().numpy()))
|
||||
for p, c in zip(top_probs[0][0].data.cpu().numpy(), top_probs[1][0].data.cpu().numpy()))
|
||||
|
||||
class ImagenetLabels:
|
||||
def __init__(self, json_path='imagenet_class_index.json'):
|
||||
def __init__(self, json_path=None):
|
||||
self._idx2label = []
|
||||
self._idx2cls = []
|
||||
self._cls2label = {}
|
||||
self._cls2idx = {}
|
||||
with open(json_path, "r") as read_file:
|
||||
|
||||
json_path = json_path or os.path.join(os.path.dirname(__file__), 'imagenet_class_index.json')
|
||||
|
||||
with open(os.path.abspath(json_path), "r") as read_file:
|
||||
class_json = json.load(read_file)
|
||||
self._idx2label = [class_json[str(k)][1] for k in range(len(class_json))]
|
||||
self._idx2cls = [class_json[str(k)][0] for k in range(len(class_json))]
|
||||
|
|
|
@ -21,17 +21,20 @@ def tensors2batch(tensors, preprocess_transform=None):
|
|||
def int2tensor(val):
|
||||
return torch.LongTensor([val])
|
||||
|
||||
def image2batch(image, image_transform=None):
|
||||
if image_transform:
|
||||
input_x = image_transform(image)
|
||||
else: # if no transforms supplied then just convert PIL image to tensor
|
||||
input_x = transforms.ToTensor()(image)
|
||||
input_x = input_x.unsqueeze(0) #convert to batch of 1
|
||||
return input_x
|
||||
|
||||
def image_class2tensor(image_path, class_index=None, image_convert_mode=None,
|
||||
image_transform=None):
|
||||
|
||||
raw_input = image_utils.open_image(os.path.abspath(image_path), convert_mode=image_convert_mode)
|
||||
if image_transform:
|
||||
input_x = image_transform(raw_input)
|
||||
else:
|
||||
input_x = transforms.ToTensor()(raw_input)
|
||||
input_x = input_x.unsqueeze(0) #convert to batch of 1
|
||||
image_pil = image_utils.open_image(os.path.abspath(image_path), convert_mode=image_convert_mode)
|
||||
input_x = image2batch(image_pil, image_transform)
|
||||
target_class = int2tensor(class_index) if class_index is not None else None
|
||||
return raw_input, input_x, target_class
|
||||
return image_pil, input_x, target_class
|
||||
|
||||
def batch_predict(model, inputs, input_transform=None, device=None):
|
||||
if input_transform:
|
||||
|
|
|
@ -3,6 +3,9 @@ from .backprop import VanillaGradExplainer
|
|||
|
||||
|
||||
def _get_layer(model, key_list):
|
||||
if key_list is None:
|
||||
return None
|
||||
|
||||
a = model
|
||||
for key in key_list:
|
||||
a = a._modules[key]
|
||||
|
@ -27,12 +30,13 @@ class GradCAMExplainer(VanillaGradExplainer):
|
|||
def backward_hook(m, grad_i, grad_o):
|
||||
self.intermediate_grad.append(grad_o[0].data.clone())
|
||||
|
||||
if self.use_inp:
|
||||
self.target_layer.register_forward_hook(forward_hook_input)
|
||||
else:
|
||||
self.target_layer.register_forward_hook(forward_hook_output)
|
||||
if self.target_layer is not None:
|
||||
if self.use_inp:
|
||||
self.target_layer.register_forward_hook(forward_hook_input)
|
||||
else:
|
||||
self.target_layer.register_forward_hook(forward_hook_output)
|
||||
|
||||
self.target_layer.register_backward_hook(backward_hook)
|
||||
self.target_layer.register_backward_hook(backward_hook)
|
||||
|
||||
def _reset_intermediate_lists(self):
|
||||
self.intermediate_act = []
|
||||
|
@ -43,14 +47,16 @@ class GradCAMExplainer(VanillaGradExplainer):
|
|||
|
||||
_ = super(GradCAMExplainer, self)._backprop(inp, ind)
|
||||
|
||||
grad = self.intermediate_grad[0]
|
||||
act = self.intermediate_act[0]
|
||||
if len(self.intermediate_grad):
|
||||
grad = self.intermediate_grad[0]
|
||||
act = self.intermediate_act[0]
|
||||
|
||||
weights = grad.sum(-1).sum(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
cam = weights * act
|
||||
cam = cam.sum(1).unsqueeze(1)
|
||||
weights = grad.sum(-1).sum(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
cam = weights * act
|
||||
cam = cam.sum(1).unsqueeze(1)
|
||||
|
||||
cam = torch.clamp(cam, min=0)
|
||||
|
||||
return cam
|
||||
cam = torch.clamp(cam, min=0)
|
||||
|
||||
return cam
|
||||
else:
|
||||
return None
|
||||
|
|
|
@ -52,7 +52,7 @@ def _get_layer_path(model):
|
|||
return ['avgpool'] #layer4
|
||||
elif model.__class__.__name__ == 'Inception3':
|
||||
return ['Mixed_7c', 'branch_pool'] # ['conv2d_94'], 'mixed10'
|
||||
else: #unknown network
|
||||
else: #TODO: guess layer for other networks?
|
||||
return None
|
||||
|
||||
def get_saliency(model, raw_input, input, label, method='integrate_grad', layer_path=None):
|
||||
|
@ -75,11 +75,14 @@ def get_saliency(model, raw_input, input, label, method='integrate_grad', layer_
|
|||
exp = _get_explainer(method, model, layer_path)
|
||||
saliency = exp.explain(input, label, raw_input)
|
||||
|
||||
saliency = saliency.abs().sum(dim=1)[0].squeeze()
|
||||
saliency -= saliency.min()
|
||||
saliency /= (saliency.max() + 1e-20)
|
||||
if saliency is not None:
|
||||
saliency = saliency.abs().sum(dim=1)[0].squeeze()
|
||||
saliency -= saliency.min()
|
||||
saliency /= (saliency.max() + 1e-20)
|
||||
|
||||
return saliency.detach().cpu().numpy()
|
||||
return saliency.detach().cpu().numpy()
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_image_saliency_results(model, raw_image, input, label,
|
||||
methods=['lime_imagenet', 'gradcam', 'smooth_grad',
|
||||
|
@ -88,15 +91,17 @@ def get_image_saliency_results(model, raw_image, input, label,
|
|||
results = []
|
||||
for method in methods:
|
||||
sal = get_saliency(model, raw_image, input, label, method=method)
|
||||
results.append(ImageSaliencyResult(raw_image, sal, method))
|
||||
|
||||
if sal is not None:
|
||||
results.append(ImageSaliencyResult(raw_image, sal, method))
|
||||
return results
|
||||
|
||||
def get_image_saliency_plot(image_saliency_results, cols = 2, figsize = None):
|
||||
def get_image_saliency_plot(image_saliency_results, cols:int=2, figsize:tuple=None):
|
||||
import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue
|
||||
|
||||
rows = math.ceil(len(image_saliency_results) / cols)
|
||||
figsize=figsize or (8, 3 * rows)
|
||||
figure = plt.figure(figsize=figsize) #figsize=(8, 3)
|
||||
figure = plt.figure(figsize=figsize)
|
||||
|
||||
for i, r in enumerate(image_saliency_results):
|
||||
ax = figure.add_subplot(rows, cols, i+1)
|
||||
|
@ -106,7 +111,8 @@ def get_image_saliency_plot(image_saliency_results, cols = 2, figsize = None):
|
|||
|
||||
#upsampler = nn.Upsample(size=(raw_image.height, raw_image.width), mode='bilinear')
|
||||
saliency_upsampled = skimage.transform.resize(r.saliency,
|
||||
(r.raw_image.height, r.raw_image.width))
|
||||
(r.raw_image.height, r.raw_image.width),
|
||||
mode='reflect')
|
||||
|
||||
image_utils.show_image(r.raw_image, img2=saliency_upsampled,
|
||||
alpha2=r.saliency_alpha, cmap2=r.saliency_cmap, ax=ax)
|
||||
|
|
|
@ -47,7 +47,7 @@ class WatcherClient(WatcherBase):
|
|||
utils.debug_log("WatcherClient is closed", verbosity=1)
|
||||
super(WatcherClient, self).close()
|
||||
|
||||
def devices_or_default(self, devices:Sequence[str])->Sequence[str]: # overriden
|
||||
def devices_or_default(self, devices:Sequence[str])->Sequence[str]: # overridden
|
||||
# TODO: this method is duplicated in Watcher and WatcherClient
|
||||
|
||||
# make sure TCP port is attached to tcp device
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
<ProjectGuid>9a7fe67e-93f0-42b5-b58f-77320fc639e4</ProjectGuid>
|
||||
<ProjectHome>
|
||||
</ProjectHome>
|
||||
<StartupFile>simple_log\sum_lazy.py</StartupFile>
|
||||
<StartupFile>post_train\saliency.py</StartupFile>
|
||||
<SearchPath>
|
||||
</SearchPath>
|
||||
<WorkingDirectory>.</WorkingDirectory>
|
||||
|
|
Загрузка…
Ссылка в новой задаче