pytorch hook
고등학생 때 DLL injection으로 피카츄 배구 해킹같은걸 했었는데, 그 때 사용했던 기법들이 일종의 hooking이다. 그런 기법들을 공식적으로 pytorch의 nn.Module에서 지원해준다.
규칙
pytorch의 hook들은 다음과 같은 규칙을 가진다.
- return이 있다면 해당 return을 본래 객체에 적용한다.
- return이 없다면 기존 객체의 동작대로 동작한다.
- hook될 함수는 객체로 전달되기 때문에 아무 이름이나 붙여되 된다.
아래 코드들을 보면 알거다.
tensor hook
tensor는 backward에 대해서만 hook을 지원한다.
torch.tensor.register_hook(function)
nn.Module hook
아래와 같은 4개의 hook을 지원한다.
- register_forward_pre_hook
- register_forward_hook
- register_backward_hook (deprecated)
- register_full_backward_hook
forward_pre_hook 형식
def pre_hook(module, input) return Anything
return이 있다면 forward의 input을 Anything으로 바꿀 수 있다. return이 없다면 단순히 input을 조회할 뿐이다.
forward_hook 형식
def hook(module, input, output) return Anything
return이 있다면 forward의 결과값이 Anything으로 교체된다. return이 없다면 단순 조회.
full_backward_hook
def module_hook(module, grad_input, grad_output)
return이 있다면 backard()를 통해 grad_ouput으로 업데이트 될 때, grad_output을 교체할 수 있다. return이 없다면 단순 교체.
Leave a comment