diff --git a/torchsat/datasets/nwpu.py b/torchsat/datasets/nwpu.py index dab51aa86..44769f4ea 100644 --- a/torchsat/datasets/nwpu.py +++ b/torchsat/datasets/nwpu.py @@ -31,6 +31,8 @@ class VHR10(VisionDataset): def __init__( self, root: str, + transform: Optional[Callable[[Any], Any]] = None, + target_transform: Optional[Callable[[Any], Any]] = None, transforms: Optional[Callable[[Any], Any]] = None, download: bool = False, ) -> None: @@ -38,11 +40,15 @@ class VHR10(VisionDataset): Parameters: root: root directory where dataset can be found + transform: a function/transform that takes in a PIL image and returns a + transformed version + target_transform: a function/transform that takes in the target and + transforms it transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory """ - super().__init__(root) + super().__init__(root, transforms, transform, target_transform) if download: self.download()