diff --git a/tests/test_jobwatcher.py b/tests/test_jobwatcher.py index 21b4fa5..e092f20 100644 --- a/tests/test_jobwatcher.py +++ b/tests/test_jobwatcher.py @@ -44,24 +44,27 @@ class TestJobWatcher(unittest.TestCase): client._check_valid_dir(1) self.assertEqual(client._check_valid_dir(CWD), CWD) - @mock.patch.object(client, 'storage_client') - @mock.patch.object(client, '_download_output') - def test_watcher_track_completed_outputs(self, mock_download, mock_storage): + @mock.patch.object(client, 'batch_client') + def test_watcher_track_completed_outputs(self, mock_batchClient): blob = namedtuple("Blob", "name, properties") - mock_storage.list_blobs.return_value = [ - blob("job_output.exr", mock.Mock(content_length=1)), - blob("subdir/job_output.png", mock.Mock(content_length=1)), - blob("thumbs/0.png", mock.Mock(content_length=1)), - blob("logs/frame_0.log", mock.Mock(content_length=1)), - blob("logs/frame_0_error.log", mock.Mock(content_length=1))] + + mock_batchClient.file.list_from_group.return_value = ({ + 'name': b.name, + 'size': b.properties.content_length} + for b in [ + blob("job_output.exr", mock.Mock(content_length=1)), + blob("subdir/job_output.png", mock.Mock(content_length=1)), + blob("thumbs/0.png", mock.Mock(content_length=1)), + blob("logs/frame_0.log", mock.Mock(content_length=1)), + blob("logs/frame_0_error.log", mock.Mock(content_length=1))]) client._track_completed_outputs("container", "\\test_dir") - mock_storage.list_blobs.assert_called_with("container") - self.assertEqual(mock_download.call_count, 4) - mock_download.assert_any_call('container', 'job_output.exr', '\\test_dir\\job_output.exr', 1) - mock_download.assert_any_call('container', 'subdir/job_output.png', '\\test_dir\\subdir\\job_output.png', 1) - mock_download.assert_any_call('container', 'logs/frame_0.log', '\\test_dir\\logs\\frame_0.log', 1) - mock_download.assert_any_call('container', 'logs/frame_0_error.log', '\\test_dir\\logs\\frame_0_error.log', 1) + + self.assertEqual(mock_batchClient.file.download.call_count, 4) + mock_batchClient.file.download.assert_any_call('\\test_dir', 'container', remote_path='job_output.exr') + mock_batchClient.file.download.assert_any_call('\\test_dir', 'container', remote_path='subdir/job_output.png') + mock_batchClient.file.download.assert_any_call('\\test_dir', 'container', remote_path='logs/frame_0.log') + mock_batchClient.file.download.assert_any_call('\\test_dir', 'container', remote_path='logs/frame_0_error.log') def test_watcher_check_job_stopped(self): mock_job = mock.create_autospec(models.CloudJob)