Docs for saliency, Credits section in Readme, gradcam ignored for unknown model types

Shital Shah 2019-05-28 01:00:32 -07:00
Parent 73f42f0a81
Commit c99f55c38e
9 changed files with 192 additions and 56 deletions

View file

@@ -117,6 +117,10 @@ We would love your contributions, feedback, questions, and feature requests! Ple
 Join the TensorWatch group on [Facebook](https://www.facebook.com/groups/378075159472803/) to stay up to date or ask any questions.
+
+## Credits
+TensorWatch utilizes several open source libraries for many of its features. These include: [hiddenlayer](https://github.com/waleedka/hiddenlayer), [torchstat](https://github.com/Swall0w/torchstat), [Visual-Attribution](https://github.com/yulongwang12/visual-attribution), [pyzmq](https://github.com/zeromq/pyzmq), [receptivefield](https://github.com/fornaxai/receptivefield), [nbformat](https://github.com/jupyter/nbformat). Please see the `install_requires` section in [setup.py](setup.py) for an up-to-date list.
+
 ## License
 This project is released under the MIT License. Please review the [License file](LICENSE.txt) for more details.

File diff suppressed because one or more lines are too long

View file

@@ -23,6 +23,6 @@ setuptools.setup(
         "Operating System :: OS Independent",
     ),
     install_requires=[
-        'matplotlib', 'numpy', 'pyzmq', 'plotly', 'torchstat', 'receptivefield', 'ipywidgets', 'sklearn'
+        'matplotlib', 'numpy', 'pyzmq', 'plotly', 'torchstat', 'receptivefield', 'ipywidgets', 'sklearn', 'nbformat'
    ]
)
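The newly added `nbformat` dependency presumably backs TensorWatch's notebook generation. As a minimal sketch of the standard `nbformat` v4 API this pulls in (illustrative only, not TensorWatch's own code):

```python
import nbformat

# Build an empty v4 notebook and append one code cell.
nb = nbformat.v4.new_notebook()
nb.cells.append(nbformat.v4.new_code_cell("import tensorwatch as tw"))

# Serialize to disk in the current (v4) notebook format.
nbformat.write(nb, 'watch.ipynb')
```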

View file

@@ -3,7 +3,7 @@
 from torchvision import transforms
 from . import pytorch_utils
-import json
+import json, os
 
 def get_image_transform():
     transf = transforms.Compose([ #TODO: cache these transforms?
@@ -21,6 +21,9 @@ def get_normalize_transform():
     return transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
 
+def image2batch(image):
+    return pytorch_utils.image2batch(image, image_transform=get_image_transform())
+
 def predict(model, images, image_transform=None, device=None):
     logits = pytorch_utils.batch_predict(model, images,
         input_transform=image_transform or get_image_transform(), device=device)
@@ -39,15 +42,18 @@ def probabilities2classes(probs, topk=5):
     top_probs = probs.topk(topk)
     # return (probability, class_id, class_label, class_code)
     return tuple((p,c, labels.index2label_text(c), labels.index2label_code(c)) \
-        for p, c in zip(top_probs[0][0].detach().numpy(), top_probs[1][0].detach().numpy()))
+        for p, c in zip(top_probs[0][0].data.cpu().numpy(), top_probs[1][0].data.cpu().numpy()))
 
 class ImagenetLabels:
-    def __init__(self, json_path='imagenet_class_index.json'):
+    def __init__(self, json_path=None):
         self._idx2label = []
         self._idx2cls = []
         self._cls2label = {}
         self._cls2idx = {}
-        with open(json_path, "r") as read_file:
+        json_path = json_path or os.path.join(os.path.dirname(__file__), 'imagenet_class_index.json')
+        with open(os.path.abspath(json_path), "r") as read_file:
             class_json = json.load(read_file)
         self._idx2label = [class_json[str(k)][1] for k in range(len(class_json))]
         self._idx2cls = [class_json[str(k)][0] for k in range(len(class_json))]
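With `json_path` now defaulting to the copy of `imagenet_class_index.json` that ships next to the module, `ImagenetLabels` works without any path setup. A rough usage sketch (the `tensorwatch.imagenet_utils` import path is an assumption based on the file shown here):

```python
# Hypothetical import path; the diff shows the module file, not the package layout.
from tensorwatch.imagenet_utils import ImagenetLabels

labels = ImagenetLabels()  # resolves imagenet_class_index.json relative to the module file
print(labels.index2label_text(207))  # prints something like 'golden_retriever'
```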

View file

@@ -21,17 +21,20 @@ def tensors2batch(tensors, preprocess_transform=None):
 def int2tensor(val):
     return torch.LongTensor([val])
 
+def image2batch(image, image_transform=None):
+    if image_transform:
+        input_x = image_transform(image)
+    else: # if no transforms supplied then just convert PIL image to tensor
+        input_x = transforms.ToTensor()(image)
+    input_x = input_x.unsqueeze(0) #convert to batch of 1
+    return input_x
+
 def image_class2tensor(image_path, class_index=None, image_convert_mode=None,
                        image_transform=None):
-    raw_input = image_utils.open_image(os.path.abspath(image_path), convert_mode=image_convert_mode)
-    if image_transform:
-        input_x = image_transform(raw_input)
-    else:
-        input_x = transforms.ToTensor()(raw_input)
-    input_x = input_x.unsqueeze(0) #convert to batch of 1
+    image_pil = image_utils.open_image(os.path.abspath(image_path), convert_mode=image_convert_mode)
+    input_x = image2batch(image_pil, image_transform)
     target_class = int2tensor(class_index) if class_index is not None else None
-    return raw_input, input_x, target_class
+    return image_pil, input_x, target_class
 
 def batch_predict(model, inputs, input_transform=None, device=None):
     if input_transform:
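The refactor hoists the PIL-image-to-batch conversion out of `image_class2tensor` into the reusable `image2batch`, so callers holding an in-memory PIL image no longer need a file path. A small sketch of the resulting behavior, assuming a 224x224 RGB input:

```python
from PIL import Image

# image2batch as defined above: with no transform supplied it falls back to
# transforms.ToTensor() and unsqueezes to a batch of one.
img = Image.open('cat.jpg').convert('RGB').resize((224, 224))
batch = image2batch(img)
print(batch.shape)  # torch.Size([1, 3, 224, 224])
```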

View file

@@ -3,6 +3,9 @@ from .backprop import VanillaGradExplainer
 
 def _get_layer(model, key_list):
+    if key_list is None:
+        return None
+
     a = model
     for key in key_list:
         a = a._modules[key]
@@ -27,12 +30,13 @@ class GradCAMExplainer(VanillaGradExplainer):
         def backward_hook(m, grad_i, grad_o):
             self.intermediate_grad.append(grad_o[0].data.clone())
 
-        if self.use_inp:
-            self.target_layer.register_forward_hook(forward_hook_input)
-        else:
-            self.target_layer.register_forward_hook(forward_hook_output)
-        self.target_layer.register_backward_hook(backward_hook)
+        if self.target_layer is not None:
+            if self.use_inp:
+                self.target_layer.register_forward_hook(forward_hook_input)
+            else:
+                self.target_layer.register_forward_hook(forward_hook_output)
+            self.target_layer.register_backward_hook(backward_hook)
 
     def _reset_intermediate_lists(self):
         self.intermediate_act = []
@@ -43,14 +47,16 @@ class GradCAMExplainer(VanillaGradExplainer):
         _ = super(GradCAMExplainer, self)._backprop(inp, ind)
-        grad = self.intermediate_grad[0]
-        act = self.intermediate_act[0]
+        if len(self.intermediate_grad):
+            grad = self.intermediate_grad[0]
+            act = self.intermediate_act[0]
 
-        weights = grad.sum(-1).sum(-1).unsqueeze(-1).unsqueeze(-1)
-        cam = weights * act
-        cam = cam.sum(1).unsqueeze(1)
+            weights = grad.sum(-1).sum(-1).unsqueeze(-1).unsqueeze(-1)
+            cam = weights * act
+            cam = cam.sum(1).unsqueeze(1)
 
-        cam = torch.clamp(cam, min=0)
-        return cam
+            cam = torch.clamp(cam, min=0)
+            return cam
+        else:
+            return None
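With these guards, `GradCAMExplainer` degrades gracefully for unknown model types: `_get_layer` returns `None`, no hooks get registered, the intermediate lists stay empty, and `_backprop` yields `None` instead of raising an `IndexError`. A sketch of the caller-side pattern this enables (the `explain` call mirrors how `saliency.py` below invokes explainers):

```python
# Hypothetical caller: tolerate GradCAM being unavailable for this architecture.
cam = explainer.explain(input_batch, class_index, raw_image)
if cam is None:
    print('GradCAM skipped: no known target layer for this model type')
else:
    heatmap = cam.squeeze().cpu().numpy()  # 2D class activation map
```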

View file

@@ -52,7 +52,7 @@ def _get_layer_path(model):
         return ['avgpool'] #layer4
     elif model.__class__.__name__ == 'Inception3':
         return ['Mixed_7c', 'branch_pool'] # ['conv2d_94'], 'mixed10'
-    else: #unknown network
+    else: #TODO: guess layer for other networks?
         return None
 
 def get_saliency(model, raw_input, input, label, method='integrate_grad', layer_path=None):
@@ -75,11 +75,14 @@ def get_saliency(model, raw_input, input, label, method='integrate_grad', layer_path=None):
     exp = _get_explainer(method, model, layer_path)
     saliency = exp.explain(input, label, raw_input)
 
-    saliency = saliency.abs().sum(dim=1)[0].squeeze()
-    saliency -= saliency.min()
-    saliency /= (saliency.max() + 1e-20)
+    if saliency is not None:
+        saliency = saliency.abs().sum(dim=1)[0].squeeze()
+        saliency -= saliency.min()
+        saliency /= (saliency.max() + 1e-20)
 
-    return saliency.detach().cpu().numpy()
+        return saliency.detach().cpu().numpy()
+    else:
+        return None
 
 def get_image_saliency_results(model, raw_image, input, label,
                                methods=['lime_imagenet', 'gradcam', 'smooth_grad',
@@ -88,15 +91,17 @@ def get_image_saliency_results(model, raw_image, input, label,
     results = []
     for method in methods:
         sal = get_saliency(model, raw_image, input, label, method=method)
-        results.append(ImageSaliencyResult(raw_image, sal, method))
+        if sal is not None:
+            results.append(ImageSaliencyResult(raw_image, sal, method))
     return results
 
-def get_image_saliency_plot(image_saliency_results, cols = 2, figsize = None):
+def get_image_saliency_plot(image_saliency_results, cols:int=2, figsize:tuple=None):
     import matplotlib.pyplot as plt # delayed import due to matplotlib threading issue
 
     rows = math.ceil(len(image_saliency_results) / cols)
     figsize=figsize or (8, 3 * rows)
-    figure = plt.figure(figsize=figsize) #figsize=(8, 3)
+    figure = plt.figure(figsize=figsize)
 
     for i, r in enumerate(image_saliency_results):
         ax = figure.add_subplot(rows, cols, i+1)
@@ -106,7 +111,8 @@ def get_image_saliency_plot(image_saliency_results, cols = 2, figsize = None):
         #upsampler = nn.Upsample(size=(raw_image.height, raw_image.width), mode='bilinear')
         saliency_upsampled = skimage.transform.resize(r.saliency,
-                                                      (r.raw_image.height, r.raw_image.width))
+                                                      (r.raw_image.height, r.raw_image.width),
+                                                      mode='reflect')
 
         image_utils.show_image(r.raw_image, img2=saliency_upsampled,
                                alpha2=r.saliency_alpha, cmap2=r.saliency_cmap, ax=ax)
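Taken together, the pipeline now skips saliency methods that don't apply to a model instead of crashing: `get_saliency` can return `None`, `get_image_saliency_results` filters those out, and the plot simply shows fewer panels. An end-to-end sketch (a torchvision ResNet and class index 207 are used purely for illustration; `image_class2tensor` is the helper from `pytorch_utils` above):

```python
import torchvision.models as models

# Load a known architecture and a sample image with an assumed target class.
model = models.resnet50(pretrained=True).eval()
raw_image, input_x, target = image_class2tensor('dog.jpg', 207)

# gradcam silently drops out for architectures _get_layer_path() doesn't know.
results = get_image_saliency_results(model, raw_image, input_x, target)
figure = get_image_saliency_plot(results, cols=2)
```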

View file

@@ -47,7 +47,7 @@ class WatcherClient(WatcherBase):
         utils.debug_log("WatcherClient is closed", verbosity=1)
         super(WatcherClient, self).close()
 
-    def devices_or_default(self, devices:Sequence[str])->Sequence[str]: # overriden
+    def devices_or_default(self, devices:Sequence[str])->Sequence[str]: # overridden
         # TODO: this method is duplicated in Watcher and WatcherClient
 
         # make sure TCP port is attached to tcp device

View file

@@ -5,7 +5,7 @@
     <ProjectGuid>9a7fe67e-93f0-42b5-b58f-77320fc639e4</ProjectGuid>
     <ProjectHome>
     </ProjectHome>
-    <StartupFile>simple_log\sum_lazy.py</StartupFile>
+    <StartupFile>post_train\saliency.py</StartupFile>
     <SearchPath>
     </SearchPath>
     <WorkingDirectory>.</WorkingDirectory>