본문 바로가기
ML & DL/pytorch

[pytorch] 커스텀 연산 with autograd

by 별준 2022. 12. 26.

References

  • Autograd mechanics (link)
  • PyTorch API documentation (link)
  • Extending PyTorch (link)

이번 포스팅에서는 파이토치의 autograd에 대해서 조금 더 세부적으로 살펴보고(주로 공식 홈페이지 내용을 참조하였습니다), 커스텀 확장하는 방법, 즉, 커스텀 op를 구현하는 방법에 알아보도록 하겠습니다. 다룰 내용은 사실 파이토치를 사용하기 위해서 꼭 알아야 하는 내용은 아닙니다만, 알아두면 파이토치의 동작에 대해 조금 더 깊게 이해할 수 있고 디버깅할 때도 유용할 수 있습니다.

 

Autograd Mechanics

How autograd encodes the history

Developer Notes에서 autograd는 reverse automatic differentiation 시스템이라고 언급하고 있습니다. 개념적으로 autograd는 데이터를 만들어내는 모든 연산을 기록하는 그래프를 기록(recording)합니다. 이때 기록되는 그래프는 리프 노드가 입력 텐서이고, 루트 노드가 출력 텐서인 방향이 있는 비순환 그래프(directed acyclic graph)입니다. 루트에서부터 리프 노드까지의 그래프를 추적하여 chain rule을 사용해 그라디언트(gradient)를 자동으로 계산할 수 있습니다.

 

내부적으로 autograd는 Function 객체의 그래프로 나타냅니다. 이 Function 객체는 apply() 메소드를 호출하여 그래프의 결과를 계산하는데 사용할 수 있습니다. Function 객체에 대해서는 아래에서 조금 더 자세히 살펴보도록 하겠습니다.

순방향(forward pass) 연산을 수행할 때 autograd는 요청된 연산을 수행하는 동시에 그라디언트를 연산하는 함수로 나타나는 그래프를 구성합니다. 이때, 구성된 그래프의 entry point는 각 torch.Tensor의 .grad_fn 속성입니다. 순방향 연산이 완료된 후, 역방향(backward pass)으로 그래프를 평가하여 그라디언트를 계산할 수 있습니다.

 

주목해야 할 점은 학습할 때 매 반복(iteration)마다 그래프가 새로 구성된다는 것입니다. 따라서, 매 반복마다 임의의 파이썬 control flow를 사용하여 그래프의 모양이나 크기를 변경할 수 있습니다.

 

Saved tensors

몇몇 연산은 역방향 연산을 수행하기 위해서 순방향 과정의 중간 결과를 저장해야 합니다. 예를 들어, \(x \mapsto x^2\) 함수는 그라디언트를 계산하기 위해서 입력인 \(x\)를 저장합니다.

 

커스텀 Function을 정의할 때, save_for_backward()를 사용하여 순방향 연산 중에 텐서를 저장할 수 있고 saved_tensors를 사용하여 역방향 연산을 할 때 저장된 텐서를 꺼내올 수 있습니다. Function에 관한 것들은 아래에서 더 자세히 살펴보고, 지금은 단편적으로만 살펴보겠습니다.

 

파이토치에서 정의하는 연산, 예를 들어, torch.pow()와 같은 연산에서 필요한 텐서들이 자동으로 저장됩니다. 어떤 텐서가 저장되었는지는 텐서의 grad_fn 속성에서 _saved로 시작하는 속성들로 확인해볼 수 있습니다.

x = torch.rand(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self)) # True
print(x is y.grad_fn._saved_self)     # True

Tensor의 equal 메소드는 텐서의 모양이나 값이 같은지 확인해주는데, 입력인 x와 y의 grad_fn에 저장된 텐서가 x와 같다는 것을 보여줍니다. 특히 line 4의 코드에서 y.grad_fn._saved_self는 x와 동일한 Tensor 객체라는 것을 보여줍니다. 하지만, 이런 케이스만 있는 것은 아닙니다.

 

아래 코드에서는 \(y = e^x\) 함수를 사용하고 있습니다.

x = torch.rand(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result)) # True
print(y is y.grad_fn._saved_result)     # False

위 코드에서 순환 참조를 피하기 위해 파이토치는 저장할 때 텐서를 pack 하고, 읽을 때 저장된 텐서를 다른 텐서로 unpack 합니다. 마지막 라인을 보면 저장된 텐서는 y.grad_fn._saved_result로 액세스할 수 있습니다. 하지만 이는 y와 다른 텐서 객체라는 것을 보여줍니다. 객체는 다르지만 동일한 저장 공간을 공유하고 있습니다.

(위의 코드와는 다르게 저장된 텐서를 _saved_result로 꺼내오는데, 그라디언트를 계산할 때 입력이 아닌 결과값이 필요하기 때문으로 보입니다: \(y' = e^x\))

 

텐서가 다른 텐서 객체로 pack되는지 여부는 이 텐서의 grad_fn의 output인지 아닌지에 따라 다르며, 이는 구현 세부사항입니다. 파이토치가 pack/unpack하는 방법은 사용자가 컨트롤할 수 있는데, Hooks for saved tensors에서 확인할 수 있습니다.

 

Gradients for non-differentiable functions

Automatic Differentiation을 사용하는 기울기 계산은 오직 함수의 모든 지점에서 미분 가능할 때만 유효합니다. 하지만, 실제로 사용되는 많은 함수들이 이 조건을 만족하지 못합니다. 예를 들어, relusqrt는 0인 지점에서 미분이 불가능합니다. 이러한 미분이 불가능한 함수를 미분하기 위해서 아래의 규칙을 순서대로 적용하여 elementary operations의 기울기를 정의합니다. 

  1. 함수가 미분이 가능하고, 현재 지점에서 그라디언트가 존재한다면 이를 사용합니다.
  2. 함수가 convex (at least locally) 라면, minimum norm의 sub-gradient를 사용합니다.
  3. 함수가 concave (at least locally) 라면, minimum norm의 super-gradient를 사용합니다.
  4. 함수가 정의된다면, 연속인 현재 지점에서의 기울기를 정의합니다. 여기서는 sqrt(0)에서처럼 inf가 가능합니다. 만약 가능한 값이 여러 개라면, 임의로 하나를 선택합니다.
  5. 만약 함수가 정의되지 않는다면(ex, sqrt(-1), log(-1) 또는 입력이 NaN일 때의 함수들), 임의의 값을 기울기로 사용합니다. 에러를 발생시킬 수도 있지만 항상 에러가 보장되지는 않습니다. 대부분의 함수는 기울기의 값으로 NaN을 사용하지만, 성능상의 이유로 몇몇 함수들은 다른 값을 사용합니다(ex, log(-1)).
  6. 만약 함수가 deterministic mapping이 아니라면, 즉, mathematical function이 아니라면, 이는 non-differentiable로 마크됩니다. 만약 no_grad 컨텍스트 외부에서 기울기를 요구하는 텐서에 사용된다면, 역방향 연산에서 에러를 발생시킵니다.

 

Locally disabling gradient computation

파이썬에서 지역적으로 기울기 계산을 비활성화하는 여러가지 방법이 있습니다.

코드 전체 블록에서 기울기를 비활성화하기 위해서는 no-grad mode와 inference mode와 같은 context managers를 사용할 수 있습니다. 조금 더 세부적으로 컨트롤하려면 텐서의 requires_grad를 설정하면 됩니다. 

추가로 evaluation mode (nn.Module.eval())도 있는데, 이 방법은 기울기 계산을 비활성화하는데 사용되지는 않습니다.

 

Setting requires_grad

nn.Parameter로 래핑되지 않는다면 기본값이 false인 requires_grad는 그래프에서 기울기 연산을 세밀하게 제외시킬 수 있는 플래그입니다. 이는 순방향과 역방향 연산 모두에 영향을 미칩니다.

 

적어도 하나의 입력 텐서가 grad를 요구한다면, 순방향 연산 중에 연산은 backward graph에 기록됩니다. 역방향 연산(.backward()) 중에 오직 requires_grad=True로 설정된 리프 텐서만 이들 텐서의 .grad 필드에 기울기 값을 누적합니다.

중요한 것은 비록 모든 텐서가 이 플래그를 가지고 있더라도 오직 리프 텐서에만 의미가 있습니다. 리프가 아닌 텐서는 리프 텐서와 연관되는 역방향 그래프를 가진 텐서입니다. 따라서, 리프가 아닌 텐서들의 기울기는 grad를 요구하는 리프 텐서들의 기울기를 계산하기 위한 중간 결과로서 필요합니다. 즉, non-leaf 텐서는 자동으로 requires_grad=True가 됩니다.

 

예를 들어, pretrained model을 fine-tuning하는 동안 일부분을 freeze할 때, requires_grad를 설정하여 컨트롤할 수 있습니다. 업데이트할 필요가 없는 파라미터에 .requires_grad_(False)를 적용하면 freeze시킬 수 있습니다. 그러면, freeze된 파라미터를 사용하는 연산은 순방향 연산 중에 기록되지 않으며, 역방향 그래프에 포함되지 않기 때문에 .grad 필드에 기울기를 업데이트하지 않습니다.

 

그리고 nn.Module.requires_grad_()를 사용하여 모듈 레벨에서 requires_grad를 설정할 수 있습니다. 이는 모듈의 전체 파라미터에 영향을 줍니다.

 

Grad Modes

requires_grad를 설정하는 것 외에도 autograd에 의해 내부적으로 처리되는 파이토치의 연산되는 방식에 영향을 주는 세 가지 가능한 모드가 있습니다. 세 가지 모드는 default mode(grad mode), no-grad mode, inference mode이며, 이들은 모두 context manager와 decorator를 통해 토글됩니다.

 

Default Mode (Grad Mode)

"default mode"는 no-grad나 inference 모드가 활성화되지 않는 경우 암묵적으로 선택되는 모드입니다. no-grad 모드와 대조하기 위해서 이 모드는 grad 모드라고도 부릅니다. 이 모드에 대해서 알아야 할 가장 중요한 것은 requires_grad가 적용되는 유일한 모드라는 것입니다. 다른 두 모드에서는 requires_grad가 항상 False로 재정의(override)됩니다.

 

No-Grad Mode

no-grad 모드에서의 연산은 마치 어떠한 입력도 grad를 필요로 하지 않는 것처럼 동작합니다. 즉, 입력 텐서의 requires_grad가 True이더라도 역방향 그래프(backward graph)에서 연산이 절대 기록되지 않습니다. 

autograd에 의해 기록되지 않아야 하는 연산을 수행해야할 때 no-grad 모드를 활성화시키면 됩니다. 하지만, 이후에 grad mode에서 이 연산에서의 출력을 여전히 사용할 수 있습니다. 이 context manager는 임시로 텐서의 requires_grad를 False로 설정하고, 다시 True로 설정할 필요없이 편리하게 함수 코드 블록에서 그라디언트를 비활성화할 수 있게 해줍니다.

 

Inference Mode

inference 모드는 no-grad 모드의 극단적인 버전입니다. no-grad 모드와 동일하게 inference 모드의 연산은 역방향 그래프에 기록되지 않지만 파이토치가 모델의 속도를 좀 더 향상시킬 수 있습니다. 단, inference 모드에서 생성된 텐서는 autograd에 의해 기록되는 연산에 사용될 수 없습니다.

inference 모드는 autograd tracking이 필요없는 data processing이나 model evaluation과 같은 코드에서 사용하는 것을 권장합니다.

 

Inferece 모드에 대한 조금 더 자세한 내용은 link를 참조바랍니다.

 

Evaluation Mode (nn.Module.eval())

evaluation 모드는 지역적으로 기울기 연산을 비활성화하는 메커니즘은 아닙니다.

기능적으로 module.eval() (or module.train(False))은 완전히 no-grad 모드와 inference 모드에 직교한다고 할 수 있습니다. model.eval()이 모델에 영향을 미치는 방법은 완전히 모델에서 사용되는 모듈과 학습 모드(training-mode)에 정의한 구체적인 동작에 따라 

 

만약 모델이 torch.nn.Dropouttorch.nn.BatchNorm2d와 같은 모듈을 사용한다면, model.eval()model.train()을 호출할 때 동작이 달라집니다. 이들을 학습 모드라고 하며, 이 학습 모드에 따라 동작이 달라지는데, 예를 들어, validation data에 대해 동작할 때 BatchNorm을 업데이트하지 않도록 해줍니다.

 

학습할 때는 항상 model.train()을 사용하고, 모델을 평가(validation/testing)할 때는 모델이 training-mode에 지정된 동작이 확실하지 않더라도 model.eval()을 사용하는 것을 권장합니다 (사용하는 모델이 training과 eval 모드에서 다르게 업데이트될 수 있기 때문).

 

 

In-place operations with autograd

autograd에서 in-place 연산을 지원하는 것은 어려운 문제이며, 대부분의 경우 사용하지 않는 것을 권장합니다. autograd의 공격적인 버퍼 해제(buffer freeing)와 재사용(reuse)는 매우 효율적이며, 상당히 많이 사용하는 메모리 양을 줄이는 경우는 거의 없습니다. 만약 메모리 사용을 최대한 줄여야하는 경우가 아니라면 이를 사용할 필요는 거의 없습니다.

 

autograd에서 in-place 연산을 제한하는 두 가지 이유는 다음과 같습니다.

  1. In-place 연산은 잠재적으로 그라디언트를 계산하도록 요구되는 값을 덮어쓸 수 있습니다.
  2. 모든 in-place 연산은 실제로 computational graph를 재작성하는 구현을 필요로 합니다. Out-of-place 버전을 새로운 객체를 할당하고 old graph에 대한 참조를 유지하면 되지만, in-place 연산은 이 연산을 나타내는 Function에 대한 모든 입력의 creator를 변경해주어야 합니다. 특히, 동일한 storage를 참조하는 텐서가 많으면 까다로울 수 있으며, 수정된 입력의 storage가 다른 Tensor가 참조하면 에러를 발생시킬 수 있습니다.

In-place correctness checks

모든 텐서는 어떤 연산에 대해 dirty(?)로 마크될 때마다 증가하는 version counter를 기록합니다. 한 Function이 역방향 연산에서 어떤 텐서를 저장할 때, 포함되는 텐서의 version counter 또한 저장됩니다. self.saved_tensors애 엑세스할 때 verison counter는 체크되며, 이 값이 만약 저장된 값보다 크다면 에러가 발생됩니다. in-place 함수를 사용할 때 이 에러가 발생하지 않는다면, 계산된 기울기가 올바르다는 것을 보장할 수 있습니다.

이외에 추가 내용들이 있지만, 제가 아직 완전히 이해하지 못한 부분들도 있고 다루는 범위가 너무 커져서 이후의 내용은 다루지 않았습니다. 혹시 더 알고 싶다면, link를 참조바랍니다.

 


Automatic Differentiation Package - torch.autograd

이번에는 automatic differentiation을 지원하는 패키지 문서에 대해서 간단히 살펴보겠습니다 (전체 내용은 link를 참조바랍니다).

torch.autograd 패키지는 임의의 스칼라 값 함수의 automatic differentiation을 구현하는 클래스와 함수들을 제공합니다. 이미 존재하는 코드에 약간의 변화만 주면 되는데, 단지 requires_grad=True 키워드가 있을 때 그라디언트를 계산하는 Tensor만 선언하면 됩니다 (backwardgrad). 현재까지 autograd는 부동소수점(half, float, double, bfloat16)과 복소수 타입(cfloat, cdouble)의 텐서에만 지원됩니다.

 

Tensor autograd functions

torch.Tensor의 autograd 관련 API는 아래와 같습니다.

 

Function

Function은 automatic differentiation을 커스텀하기 위해 필요한 베이스 클래스입니다. 커스텀 autograd.Function을 구현하기 위해서는 이 클래스를 서브클래싱해야 하며, forward()backward()라는 static method를 구현해야 합니다. 그리고 나서 순방향 연산에서 구현한 커스텀 op를 사용하려면 직접 forward()를 호출하지 말고, 클래스 메소드인 apply를 호출하면 됩니다.

 

올바르고 최상의 성능을 보장하기 위해서, ctx(Context)에 대한 올바른 메소드를 호출하고 torch.autograd.gradcheck()를 사용하여 역방향 함수를 검증해야 합니다. 이에 대해서는 아래에서 커스텀 torch.autograd를 직접 작성해보면서 더 자세히 살펴보도록 하겠습니다.

 

커스텀 Function은 다음과 같이 구현되고, 사용될 수 있습니다.

from torch.autograd import Function

class Exp(Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)
        return result
    
    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

# Use it by calling the apply method
output = Exp.apply(input)

 

Context method mixins

Function을 새로 구현할 때, ctx에 대해서 사용가능한 메소드는 다음과 같습니다.

 

Numerical gradient checking

커스텀 Function을 구현할 때, 역방향 연산을 검증해야 한다고 언급했었습니다. 이는 torch.autograd에서 제공하는 아래의 API를 사용하여 검증할 수 있습니다.

 


Extending torch.autograd

autograd를 지원하는 연산을 추가하려면 Function을 서브클래싱하는 연산을 구현해야 합니다. 먼저 backward mode AD(automatic differentiation)에 대해서 살펴보고, 마지막에 forward mode AD에 대해 논의해보도록 하겠습니다.

 

When to use

일반적으로 미분이 불가능하거나 파이토치 라이브러리에 의존하지 않는 연산을 사용해야 하지만 이 연산이 다른 연산들(ops)과 엮이고 autograd engine에서 동작하기를 원할 때, 커스텀 함수를 구현합니다. 잘만 구현한다면 커스텀 함수는 성능과 메모리 사용을 향상시킬 수도 있습니다.

 

When not to use

만약 이미 파이토치에서 제공하는 built-in ops만을 사용하여 함수를 구현한다면, 대부분의 경우 이 연산의 backward graph는 autograd에 의해 기록됩니다. 따라서, 이 경우에는 backward function을 직접 구현할 필요가 없습니다.

 

만약 state (즉, trainable parameters)를 유지해야 한다면, 커스텀 모듈을 사용할 수 있으며 이는 torch.nn을 확장하면 됩니다.

 

이외에 역방향 연산 중에 그라디언트를 변경하거나 다른 효과를 수행하도록 하려면 tensor hook 또는 Module hook을 등록하는 것을 고려하면 됩니다 (저도 아직 자세히 알지는 못하는 부분이라 이번 포스팅에서는 간단하게 이런 것이 있다는 것만 언급하고 넘어가겠습니다).

 

How to use

커스텀 Function 구현은 다음의 순서를 따라서 구현하면 됩니다.

  1. Function을 서브클래싱하고 forward()와 backward() 메소드를 구현합니다.
  2. ctx 인자의 적절한 메소드를 호출합니다.
  3. double backward를 지원하는 함수인지 선언합니다. (이번 포스팅에서 double backward는 다루지 않습니다)
  4. gradcheck를 사용하여 구현한 함수의 그라디언트가 올바른지 검증합니다.

위의 각 단계를 조금 더 세부적으로 살펴보겠습니다.

Step 1

Function을 서브클래싱하면, forward()와 backward() 함수를 구현해야 합니다.

 

forward()는 연산(operation)을 수행하는 함수입니다. 이 함수는 원하는 만큼 인자를 받을 수 있고, (default 값이 있는)optional 인자도 받을 수 있습니다. 여기에는 모든 종류의 파이썬 객체가 인자로 전달될 수 있습니다. requires_grad=True로 설정되어 기록을 추적하는 Tensor 인자는 forward() 호출 이후에는 추적되지 않도록 변환됩니다. 그리고 이 텐서는 그래프에 등록됩니다. 이러한 로직은 list,dict, 또는 다른 데이터 구조에는 적용되지 않으며 forward() 호출의 직접적인 인수인 텐서만 적용됩니다.

forward() 함수는 여러 개의 출력이 있다면 텐서들의 튜플로 반환할 수 있습니다.

 

backward()(또는 vjp())는 gradient formula를 정의합니다. 출력 갯수 만큼의 Tensor 인자가 전달되고, 이는 각각의 출력에 대한 기울기를 나타냅니다. 여기서 in-place로 수정하지 않도록 하는 것이 중요합니다.

backward() 함수는 입력 수 만큼의 텐서를 반환해야 하며, 이 텐서는 대응하는 입력에 대한 그라디언트를 포함합니다. 만약 입력이 그라디언트를 필요로 하지 않거나, Tensor 객체가 아닌 경우에는 None을 반환하면 됩니다 (needs_input_grad를 통해 그라디언트 연산이 필요한지 알 수 있습니다. 아래의 예제 참조).

Step 2

새로 구현한 Function이 autograd 엔진으로 잘 동작하려면 forward() 함수에서 ctx의 함수를 적절하게 잘 사용해야 합니다. 사용할 수 있는 ctx의 함수는 아래에 나열되어 있습니다.

  • save_for_backward() : 역방향 연산에서 사용될 텐서들을 저장하는데 사용합니다. 텐서가 아닌 것들은 ctx에 바로 저장되어야 합니다. 만약 입력이나 출력이 아닌 텐서가 저장되면 double backward를 지원하지 않을 수 있습니다 (Step 3 참조).
  • make_dirty() : forward 함수에서 in-place로 수정되는 입력을 마킹하는데 사용됩니다.
  • make_non_differentiable() : 출력이 미분 불가능하다고 엔진에게 알려주는데 사용됩니다. 기본값으로 모든 출력 텐서는 미분가능한 타입이며, 그라디언트를 요구하도록 설정됩니다. 미분 불가능한 타입(e.g., integral types)의 텐서는 절대 그라디언트를 요구하도록 마킹되지 않습니다.
  • set_materialize_grads() : 출력이 입력에 의존하지 않는 경우에 backward() 함수에 주어진 grad tensors를 사용하지 않도록 하여 기울기 연산을 최적화하도록 엔진에게 알려주는데 사용합니다. 즉, False로 설정한다면, 파이썬의 None 객체 또는 C++의 "undefined tensor"(tensor.defined()가 False인 tensor)는 backward()를 호출하기 전에 0으로 채워진 텐서로 변환되지 않습니다. 그래서 코드에서는 이러한 객체를 마치 0으로 채워진 텐서인 것처럼 처리해야 합니다. 기본 설정 값은 True 입니다.

Step 3

만약 구현할 Function이 double backward를 지원하지 않는다면, 명시적으로 backward 함수에 once_differentiable()로 데코레이트해야 합니다. 이 데코레이터가 있으면, 함수 내에서 double backward를 시도할 때 에러가 발생됩니다. double backward에 대해서는 이번 포스팅에서 다루지 않으며, 공식 문서(link)를 참조바랍니다.

Step 4

구현한 backward 함수가 forward의 그라디언트를 올바르게 계산하는지 torch.autograd.gradcheck()를 사용하여 확인하는 것을 권장합니다. 이 함수는 backward 함수를 사용하여 야코비안 행렬을 계산하고, finite-differencing을 사용하여 요소별 값을 앞서 계산된 야코비안과 비교합니다.

 

Example 1

아래 예제 코드는 간단한 Linear 함수를 구현합니다.

# Inherit from torch.autograd.Function
class LinearFunction(torch.autograd.Function):
    # Note that both forward and backward are @staticmethod
    
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t()) # y = wx
        
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output) # y += b (y = wx + b)
        
        return output
    
    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # unpack saved_tensors
        input, weight, bias = ctx.saved_tensors
        # initialize all gradients with respect to inputs to None
        grad_input = grad_weight = grad_bias = None
        
        # using needs_input_grad to check where each input needs gradient computation
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        
        return grad_input, grad_weight, grad_bias

Linear 함수의 backpropagation에 대해서 간단하게 설명하는 페이지가 있어서 링크 남겨둡니다. 단, 아래 링크에서는 코드와 비교해서 weight의 차원이 transpose되어 있음에 유의하시길 바랍니다.

 

Backpropagation for a Linear Layer

Website for UMich EECS 442 course

web.eecs.umich.edu

 

이렇게 구현한 커스텀 ops는 apply 메소드를 사용하여 사용할 수 있습니다.

linear = LinearFunction.apply

 

Example 2

non-Tensor 인자를 받는 함수의 경우, 구현은 다음과 같습니다. 아래에서 구현한 연산은 단순히 텐서에 상수를 곱하는 연산입니다.

class MulConstant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, constant):
        # ctx can be used to stash information for backward computation
        ctx.constant = constant
        return tensor * constant
    
    @staticmethod
    def backward(ctx, grad_output):
        # Note that we return as many input gradients as there were arguments.
        # And, gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None

텐서가 아닌 인자는 ctx에 직접 저장되는 것을 볼 수 있습니다. 그리고 backward() 함수는 forward()의 입력 수 만큼 그라디언트를 반환해야 하는데, 텐서가 아닌 입력에 대해서는 None을 반환합니다.

 

위의 코드에서 그라디언트 계산에 입력은 사용되지 않습니다. 따라서, 이 경우 set_materialize_grads(False)를 호출하여 최적화할 수 있습니다.

class MulConstant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, constant):
        ctx.set_materialize_grads(False)
        ctx.constant = constant
        return tensor * constant
    
    @staticmethod
    def backward(ctx, grad_output):
        # Here we must handle None grad_output tensor.
        # In this case we can skip unnecessary computations
        if grad_output is None:
            return None, None
        
        return grad_output * ctx.constant, None

 

backward() 함수의 입력, 즉, grad_output 또한 기록을 추적하는 텐서일 수 있습니다. 

 

torch.autograd.gradcheck

구현한 backward 함수가 올바르게 구현되었는지 확인하기 위해서 torch.autograd.gradcheck를 사용할 수 있습니다. 만약 고차 도함수를 사용한다면 gradgradcheck를 사용하여 확인할 수 있습니다.

from torch.autograd import gradcheck

# gradcheck takes a tuple of tensors as input,
# check if your gradient evaluated with these tensors are close enough to numerical approximations and
# returns True if they all verify this condition
input = (torch.randn(20, 20, dtype=torch.double, requires_grad=True),
        torch.randn(30, 20, dtype=torch.double, requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)

 

Extending torch.nn

위에서 구현한 커스텀 op인 LinearFunction는 주로 torch.nn.Module을 확장하여 nn.Module 내에서 사용하는 경우 많을 것 같습니다. Module을 확장하는 것은 아마 잘 알려져 있고, 네트워크 모델을 구현에서 이를 사용하는 예제가 많이 보입니다. 마지막으로 위에서 구현한 커스텀 op를 사용하는 예제 코드를 아주 간단히만 살펴보겠습니다.

 

nn 패키지는 autograd를 아주 많이 활용하기 때문에 새로운 nn.Module을 구현하기 위해서는 연산을 수행하고 그라디언트를 계산할 수 있는 Function이 필요합니다. 아래에서 볼 예제 코드에서는 Linear 모듈을 구현하며, 아래의 두 함수만 구현해주면 됩니다.

  • __init__ (optional) : 커널의 사이즈나 feature의 수 등의 초기화해야할 파라미터나 버퍼가 있는 경우 구현합니다.
  • forward() : Function을 인스턴스화하고 이를 사용하여 연산을 수행합니다. 이는 Function의 래퍼라고 볼 수 있습니다.

구현 예시는 다음과 같습니다.

class Linear(torch.nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(Linear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        
        # nn.Parameter is a special kind of Tensor, that will get
        # automatically registered as Module's parameter once it's assigned
        # as an attribute. Parameters and buffers need to be registered, or
        # they won't appear in .parameters() (doesn't apply to buffers), and
        # won't be converted when e.g. .cuda() is called. You can use
        # .register_buffer() to register buffers.
        # nn.Parameters require gradients by default
        self.weight = torch.nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)
        
        # Not a very smart way to initialize weights
        torch.nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            torch.nn.init.uniform_(self.bias, -0.1, 0.1)
    
    def forward(self, input):
        return LinearFunction.apply(input, self.weight, self,bias)
    
    def extra_repr(self):
        # (optional) Set the extra information about this module.
        # You can test it by printing an object of this class
        return f'input_features={self.input_features}, output_features={self.output_features}, bias={self.bias is not None}'
linear = Linear(10,20)
print(linear)

 

댓글