From a25e0e22384f1fbc74fbe0a7c09eea231d81044b Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 12 May 2021 17:48:53 -0500 Subject: [PATCH] Allow more transforms --- torchsat/datasets/nwpu.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchsat/datasets/nwpu.py b/torchsat/datasets/nwpu.py index dab51aa86..44769f4ea 100644 --- a/torchsat/datasets/nwpu.py +++ b/torchsat/datasets/nwpu.py @@ -31,6 +31,8 @@ class VHR10(VisionDataset): def __init__( self, root: str, + transform: Optional[Callable[[Any], Any]] = None, + target_transform: Optional[Callable[[Any], Any]] = None, transforms: Optional[Callable[[Any], Any]] = None, download: bool = False, ) -> None: @@ -38,11 +40,15 @@ class VHR10(VisionDataset): Parameters: root: root directory where dataset can be found + transform: a function/transform that takes in a PIL image and returns a + transformed version + target_transform: a function/transform that takes in the target and + transforms it transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory """ - super().__init__(root) + super().__init__(root, transforms, transform, target_transform) if download: self.download()