Per-example gradients: a new frontier for understanding and improving optimizers

이 논문은 자동미분 그래프 수정이나 JAX 의 벡터화를 통해 계산 및 메모리 오버헤드 없이 개별 예제 기울기를 효율적으로 계산할 수 있음을 보이며, 이를 통해 SGD 와 Adam 과 같은 최적화 알고리즘의 설계 원리를 재검토하고 새로운 분석 및 알고리즘 개발의 가능성을 제시합니다.

Vincent Roulet, Atish Agarwala

게시일 2026-03-03
📖 3 분 읽기☕ 가벼운 읽기

Each language version is independently generated for its own context, not a direct translation.

🍽️ 비유: 식당의 요리사와 메뉴판

딥러닝 모델을 훈련시킨다는 것은, 수많은 학생 (데이터) 을 가르쳐서 한 명의 천재 (모델) 를 만드는 과정과 비슷합니다. 이때 '최적화 알고리즘'은 학생들의 답을 보고 "어디가 틀렸는지" 지적해주는 선생님 역할을 합니다.

1. 기존의 방식: "평균 점수만 보는 선생님"

기존의 딥러닝 훈련 방식은 **미니배치 (Mini-batch)**라는 작은 그룹의 학생들 (예: 64 명) 을 한 번에 봅니다. 그리고 이 64 명의 **평균 점수 (평균 기울기)**만 계산해서 "너희 전체가 이 방향으로 공부해라"라고 지시합니다.

  • 문제점: 64 명 중 60 명은 잘 풀었는데, 4 명만 완전히 엉뚱한 답을 썼다면? 평균을 내면 그 4 명의 엉뚱한 정보가 사라져버립니다. 혹은 4 명은 아주 잘했는데 60 명이 엉망이라면, 그 4 명의 재능이 묻힙니다.
  • 기존의 생각: "개별 학생의 오답을 하나하나 분석하는 건 너무 귀찮고, 컴퓨터 메모리도 너무 많이 먹어서 불가능해."라고 생각했습니다.

2. 이 논문의 혁신: "개별 학생의 오답 노트를 보는 선생님"

이 논문은 **"아니, 개별 학생 (Per-example) 의 오답 노트를 하나하나 분석하는 게 그렇게 어렵지 않아!"**라고 말합니다.

  • 기술적 비유 (JAX 와 '수술'):
    컴퓨터가 계산을 할 때, 중간 과정들을 기록해둡니다. 보통은 이 기록들을 다 합쳐서 평균만 내버리는데, 이 논문은 **"아, 이 중간 기록들 (계산 그래프) 을 살짝 '수술'해서, 평균을 내기 전에 개별 학생들의 오답 노트를 따로 떼어낼 수 있구나!"**라고 발견했습니다.
    • 마치 식당에서 64 개의 접시를 한 번에 씻는 대신, 접시 하나하나의 얼룩을 따로 찍어서 분석할 수 있는 새로운 세척기를 개발한 것과 같습니다.
    • 놀랍게도 이 방법을 쓰면 컴퓨터 메모리나 속도가 크게 느려지지 않습니다. (특히 최신 AI 모델인 '트랜스포머' 구조에서는 거의 비용이 들지 않아요.)

3. 발견한 두 가지 중요한 사실

이제 개별 학생의 오답 노트를 볼 수 있게 되자, 두 가지 놀라운 사실을 발견했습니다.

① '부호 (Sign)'를 언제 찍어야 할까? (SignSGD)

  • 상황: 학생들의 답이 너무 복잡해서, "맞았으면 (+), 틀렸으면 (-)"만 보고 가르치기로 했습니다. (부호만 사용하는 최적화)
  • 발견: "틀린 답"을 고칠 때, 64 명을 다 합쳐서 평균을 낸 뒤에 "부호"를 찍는 것이 가장 좋습니다.
  • 이유: 개별 학생의 오답 노트를 먼저 보고 부호를 찍으면, 그 학생의 '우연한 실수 (노이즈)'까지 그대로 반영되어 혼란을 줍니다. 하지만 64 명을 합쳐 평균을 내면 우연한 실수는 사라지고 진짜 '방향'만 남습니다.
    • 결론: "일단 다 합쳐서 평균을 내고, 그다음에 방향을 정해라."

② '분산'보다 '평균의 제곱'이 더 중요하다 (Adam)

  • 상황: 기존에 유명한 'Adam'이라는 알고리즘은 학생들의 답이 얼마나 '흩어져 있는지 (분산)'를 중요하게 여겼습니다. "답이 들쑥날쑥하면 조심해야지"라는 논리입니다.
  • 발견: 이 논문의 실험 결과, **분산 (흩어짐) 보다는 '평균의 제곱 (진짜 방향의 힘)'**이 훨씬 더 중요합니다.
  • 비유: 64 명이 모두 "왼쪽으로 가자"라고 말하면 (평균이 강함), 그 방향이 진짜입니다. 하지만 32 명은 "왼쪽", 32 명은 "오른쪽"이라고 말하면 (분산이 큼), 평균은 0 이 되어 방향을 잃습니다.
    • 결론: "답이 흩어지는지 (분산) 보다, 진짜 방향이 얼마나 강한지 (평균의 제곱) 를 더 믿어라." 기존 상식과 정반대이지만, 실험 결과 이것이 더 빠르고 안정적으로 모델을 가르칩니다.

💡 요약: 왜 이것이 중요한가?

  1. 가능성 증명: "개별 데이터를 분석하는 건 너무 비싸다"는 편견을 깨뜨렸습니다. 현대적인 컴퓨터 기술 (JAX 등) 을 쓰면 거의 비용 없이 가능합니다.
  2. 새로운 통찰: 개별 데이터를 분석하면, 우리가 몰랐던 최적화 알고리즘의 비밀 (부호를 언제 찍을지, 흩어짐보다 방향을 믿을지) 을 찾아낼 수 있습니다.
  3. 미래: 이제 연구자들은 개별 데이터의 정보를 활용하여 더 빠르고, 더 똑똑한 AI 훈련 방법을 개발할 수 있는 새로운 문을 열었습니다.

한 줄 요약:
"AI 를 가르칠 때, 학생들 전체의 '평균 점수'만 보지 말고, '개별 오답 노트'를 살짝 훑어보는 것이 훨씬 더 똑똑하고 빠른 학습을 가능하게 합니다."

이런 논문을 받은편지함으로 받아보세요

관심사에 맞는 일간 또는 주간 다이제스트. Gist 또는 기술 요약을 당신의 언어로.

Digest 사용해 보기 →