본문 바로가기
딥러닝

[KD] On-Policy Distillation of Language Models: learning from Self-Generated Mistakes

by 방구석 데이터사이언티스트 2024. 12. 31.
728x90
반응형

Instruction

일반적으로 student는 teacher보다 더 적은 파라미터를 가지므로, distillation을 통해 teacher보다 적은 추론 비용과 메모리 풋프린트을 유지하면서 특정 작업의 성능을 향상시킬 수 있다. auto-regrressive sequence model을 위한 현재의 distillation 방법은 비용이 많이 들 수 있는 teacher에서 fixed set of output sequences을 생성하거나 teacher가 token-level 확률을 레이블로 할당하는 fixed 시퀀스 데이터를 생성해야한다.

 

  • 그러나 fixed 데이터셋은 훈련 중 보이는 출력 과 추론 중 생성하는 출력의 분포 불일치를 이끌 수 있다.
  • 더불어, 학생은 teacher의 분포를 맞추기에 표현력이 충분하지 않을 수 있다. 이는 학생이 teacher에서 생성하지 않는 샘플을 생성할 수도 있다는 점이다.

이 논문은 위 문제를 완화하기 위해 Generalized KD (GKD)을 제안한다.

  1. auto-regrressive sequence model을 위한 KD는 대화형 전문가와의 모방 학습 문제(task or problem)로 본다. GKD는 teacher의 확률을 expert labels로서 사용하고, fixed set of outputs sequences 대신 self-generated sequences를 학습한다.
  2. GKD는 다양한 divergence에 최적화할 수 있는 유연성을 보여준다.

Contribution

  • auto-regressive LM의 학습과 추론 동안의 불일치를 해결하기 위해, GKD를 제안한다. GKD는 증류를 위해 the token-level teacher probabilities 따른 on-policy student-generated outputs을 활용한다.
  • in task-specific (위 선 그래프) and task-agnostic KD (위 막대 그래프)에서 흔히 사용되는 방법보다 뛰어나다.
  • on-policy GKD는 RL fine-tuning과 원활하게 결합할 수 있다.
  • student-generated on-policy output sequences를 사용하는 것의 중요성과 optimal divergence의 작업-의존적 특징에 대한 실질적인 인사이트를 제공한다.

Method : DISTILLATION FOR AUTO-REGRESSIVE SEQUENCE MODELS

  • 학습과 추론에 대한 분포 불일치는 student의 추론 과정에서 생성된 sequence 때문이며, 초기 잘못된 생성은 이후 예측에 영향을 미치는 cascading effect를 가진다. → 이 문제를 imitation learning 통해 해결하려한다.
  • on-policy imitation approaches는 반복적으로 student pilicy를 통해 sequences를 수집하고, 그sequences를 위한 expert labels을 획득하여, 그 데이터로 student를 재학습한다.

on-policy imitation을 distillation으로 확장!

더 나아가, supervised KD 와 on-policy KD를 통합하여 좀더 일반적인 방법인, GKD를 제안한다.

  • GKD는 최적화하기 위한 divergence와 훈련하기 위한 output sequence를 모두 선택할 수 있다.
    • teacher와 student의 token-level probability distribution에 대한 여러 divergence를 최적화 할 수 있다.
    • GKD는 ground-truth, teacher-generated, on-policy student-generated sequences를 혼합한 데이터를 사용한다.
  • Choice of Divergence in GKD : Forward KD는 mean-seeking(mode-covering)하며, Reverse KD는 mode-seeking한 특성을 지닌다. 각 태스크 마다 최적화의 divergence가 다를 수 있다. GKD divergence를 선택할 때, 특정 태스크에 필요한 다양성과 성능의 트레이드 오프를 고려해야한다.

EXPERIMENTS

Student / Teacher Models (goole 논문이라 그런지 T5만 사용함)

  • Teacher: supervised fine-tuned T5-XL (∼ 3B params)
  • Students : T5-small (77M params), T5-base (250M params), T5-large (800M params)

GKD Variants.

  • forward KL, reverse KL, three variants of JSD(β): JSD (0.1), JSD (0.5) and JSD (0.9)
  • student data fraction을 위해, λ = 1 (On-policy), λ = 0.5 (Mixed) and λ = 0 (Supervised) 등 여러 설정 실험. → 이전에 탐구가 없었던 λ = 1 (On-policy)가 주요 관심 요소.

abstractive summarization task

  • On-policy GKD with JSD (0.9)는 greedy sampling과 temperature sampling (γ = 1)에서 모두 베이스라인을 능가함.(Figure 2)
  • ground-truth summaries가 포함되지않은 5%의 일부 데이터셋의 on-policy GKD 결과가 ground-truth summaries가 포함된 전체 데이터셋의 supervised KD와 ImitKD의 결과를 능가함. (Figure 3)
  • 더 많지만 여기까지..

기타

728x90
반응형

댓글