This commit is contained in:
Ritwik Das 2023-04-11 20:38:17 -07:00
Родитель 2df22de71b
Коммит e3d0834be1
4 изменённых файла: 12 добавлений и 12 удалений

Просмотреть файл

@@ -8,10 +8,10 @@ class CallableFunc(ABC):
try:
self.init_runtime(benchmark=False, device_id=device_id, working_dir=working_dir)
try:
self.init_main(benchmark=False, args=args, device_id=device_id)
self.init_batch(benchmark=False, args=args, device_id=device_id)
timing = self.run_batch(benchmark=False, iters=1, args=args)
finally:
self.cleanup_main(benchmark=False, args=args)
self.cleanup_batch(benchmark=False, args=args)
finally:
self.cleanup_runtime(benchmark=False, working_dir=working_dir)
@@ -21,7 +21,7 @@ class CallableFunc(ABC):
try:
self.init_runtime(benchmark=True, device_id=device_id, working_dir=working_dir)
try:
self.init_main(benchmark=True, warmup_iters=warmup_iters, args=args, device_id=device_id)
self.init_batch(benchmark=True, warmup_iters=warmup_iters, args=args, device_id=device_id)
# Run multiple batches
batch_timings_ms: List[float] = []
@@ -37,7 +37,7 @@ class CallableFunc(ABC):
mean_elapsed_time_ms = sum(batch_timings_ms) / iterations
finally:
self.cleanup_main(benchmark=True, args=args)
self.cleanup_batch(benchmark=True, args=args)
finally:
self.cleanup_runtime(benchmark=True, working_dir=working_dir)
return mean_elapsed_time_ms, batch_timings_ms
@@ -47,7 +47,7 @@ class CallableFunc(ABC):
...
@abstractmethod
def init_main(self, benchmark: bool, warmup_iters=0, device_id: int=0, *args):
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int=0, *args):
...
@abstractmethod
@@ -55,7 +55,7 @@ class CallableFunc(ABC):
...
@abstractmethod
def cleanup_main(self, benchmark: bool, *args):
def cleanup_batch(self, benchmark: bool, *args):
...
@abstractmethod

Просмотреть файл

@@ -204,7 +204,7 @@ class CudaCallableFunc(CallableFunc):
def cleanup_runtime(self, benchmark: bool, working_dir: str):
cuda.cuCtxDestroy(self.context)
def init_main(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
self.func_info.verify(args[0] if benchmark else args)
self.device_mem = allocate_cuda_mem(self.func_info.arguments)
@@ -265,7 +265,7 @@ class CudaCallableFunc(CallableFunc):
return batch_time_ms
def cleanup_main(self, benchmark: bool, args=[]):
def cleanup_batch(self, benchmark: bool, args=[]):
# If there's no device mem, that means allocation during initialization failed, which means nothing else needs to be cleaned up either
if not benchmark and self.device_mem:
transfer_mem_cuda_to_host(device_args=self.device_mem, host_args=args, arg_infos=self.func_info.arguments)

Просмотреть файл

@@ -110,7 +110,7 @@ class HostCallableFunc(CallableFunc):
os.remove(target_file)
def init_main(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
self.func_info.verify(args[0] if benchmark else args)
for _ in range(warmup_iters):
@@ -130,7 +130,7 @@ class HostCallableFunc(CallableFunc):
return float(self.timing_arg_val.value)
def cleanup_main(self, benchmark: bool, args=[]):
def cleanup_batch(self, benchmark: bool, args=[]):
pass
def should_flush_cache(self) -> bool:

Просмотреть файл

@@ -122,7 +122,7 @@ class RocmCallableFunc(CallableFunc):
def cleanup_runtime(self, benchmark: bool, working_dir: str):
pass
def init_main(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
self.func_info.verify(args[0] if benchmark else args)
self.device_mem = allocate_rocm_mem(benchmark, self.func_info.arguments, device_id)
@@ -167,7 +167,7 @@ class RocmCallableFunc(CallableFunc):
return batch_time_ms
def cleanup_main(self, benchmark: bool, args=[]):
def cleanup_batch(self, benchmark: bool, args=[]):
# If there's no device mem, that means allocation during initialization failed, which means nothing else needs to be cleaned up either
if not benchmark and self.device_mem:
transfer_mem_rocm_to_host(device_args=self.device_mem, host_args=args, arg_infos=self.func_info.arguments)