зеркало из https://github.com/microsoft/hat.git
PR comments
This commit is contained in:
Родитель
2df22de71b
Коммит
e3d0834be1
|
@ -8,10 +8,10 @@ class CallableFunc(ABC):
|
|||
try:
|
||||
self.init_runtime(benchmark=False, device_id=device_id, working_dir=working_dir)
|
||||
try:
|
||||
self.init_main(benchmark=False, args=args, device_id=device_id)
|
||||
self.init_batch(benchmark=False, args=args, device_id=device_id)
|
||||
timing = self.run_batch(benchmark=False, iters=1, args=args)
|
||||
finally:
|
||||
self.cleanup_main(benchmark=False, args=args)
|
||||
self.cleanup_batch(benchmark=False, args=args)
|
||||
finally:
|
||||
self.cleanup_runtime(benchmark=False, working_dir=working_dir)
|
||||
|
||||
|
@ -21,7 +21,7 @@ class CallableFunc(ABC):
|
|||
try:
|
||||
self.init_runtime(benchmark=True, device_id=device_id, working_dir=working_dir)
|
||||
try:
|
||||
self.init_main(benchmark=True, warmup_iters=warmup_iters, args=args, device_id=device_id)
|
||||
self.init_batch(benchmark=True, warmup_iters=warmup_iters, args=args, device_id=device_id)
|
||||
|
||||
# Run multiple batches
|
||||
batch_timings_ms: List[float] = []
|
||||
|
@ -37,7 +37,7 @@ class CallableFunc(ABC):
|
|||
|
||||
mean_elapsed_time_ms = sum(batch_timings_ms) / iterations
|
||||
finally:
|
||||
self.cleanup_main(benchmark=True, args=args)
|
||||
self.cleanup_batch(benchmark=True, args=args)
|
||||
finally:
|
||||
self.cleanup_runtime(benchmark=True, working_dir=working_dir)
|
||||
return mean_elapsed_time_ms, batch_timings_ms
|
||||
|
@ -47,7 +47,7 @@ class CallableFunc(ABC):
|
|||
...
|
||||
|
||||
@abstractmethod
|
||||
def init_main(self, benchmark: bool, warmup_iters=0, device_id: int=0, *args):
|
||||
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int=0, *args):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
|
@ -55,7 +55,7 @@ class CallableFunc(ABC):
|
|||
...
|
||||
|
||||
@abstractmethod
|
||||
def cleanup_main(self, benchmark: bool, *args):
|
||||
def cleanup_batch(self, benchmark: bool, *args):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
@ -204,7 +204,7 @@ class CudaCallableFunc(CallableFunc):
|
|||
def cleanup_runtime(self, benchmark: bool, working_dir: str):
|
||||
cuda.cuCtxDestroy(self.context)
|
||||
|
||||
def init_main(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
|
||||
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
|
||||
self.func_info.verify(args[0] if benchmark else args)
|
||||
self.device_mem = allocate_cuda_mem(self.func_info.arguments)
|
||||
|
||||
|
@ -265,7 +265,7 @@ class CudaCallableFunc(CallableFunc):
|
|||
|
||||
return batch_time_ms
|
||||
|
||||
def cleanup_main(self, benchmark: bool, args=[]):
|
||||
def cleanup_batch(self, benchmark: bool, args=[]):
|
||||
# If there's no device mem, that means allocation during initialization failed, which means nothing else needs to be cleaned up either
|
||||
if not benchmark and self.device_mem:
|
||||
transfer_mem_cuda_to_host(device_args=self.device_mem, host_args=args, arg_infos=self.func_info.arguments)
|
||||
|
|
|
@ -110,7 +110,7 @@ class HostCallableFunc(CallableFunc):
|
|||
os.remove(target_file)
|
||||
|
||||
|
||||
def init_main(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
|
||||
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
|
||||
self.func_info.verify(args[0] if benchmark else args)
|
||||
|
||||
for _ in range(warmup_iters):
|
||||
|
@ -130,7 +130,7 @@ class HostCallableFunc(CallableFunc):
|
|||
|
||||
return float(self.timing_arg_val.value)
|
||||
|
||||
def cleanup_main(self, benchmark: bool, args=[]):
|
||||
def cleanup_batch(self, benchmark: bool, args=[]):
|
||||
pass
|
||||
|
||||
def should_flush_cache(self) -> bool:
|
||||
|
|
|
@ -122,7 +122,7 @@ class RocmCallableFunc(CallableFunc):
|
|||
def cleanup_runtime(self, benchmark: bool, working_dir: str):
|
||||
pass
|
||||
|
||||
def init_main(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
|
||||
def init_batch(self, benchmark: bool, warmup_iters=0, device_id: int = 0, args=[]):
|
||||
self.func_info.verify(args[0] if benchmark else args)
|
||||
self.device_mem = allocate_rocm_mem(benchmark, self.func_info.arguments, device_id)
|
||||
|
||||
|
@ -167,7 +167,7 @@ class RocmCallableFunc(CallableFunc):
|
|||
|
||||
return batch_time_ms
|
||||
|
||||
def cleanup_main(self, benchmark: bool, args=[]):
|
||||
def cleanup_batch(self, benchmark: bool, args=[]):
|
||||
# If there's no device mem, that means allocation during initialization failed, which means nothing else needs to be cleaned up either
|
||||
if not benchmark and self.device_mem:
|
||||
transfer_mem_rocm_to_host(device_args=self.device_mem, host_args=args, arg_infos=self.func_info.arguments)
|
||||
|
|
Загрузка…
Ссылка в новой задаче