연합 학습(Federated Learning): 사용자 데이터를 공유하지 않고 AI를 학습시키는 방법

연합 학습의 개념도

기존의 머신러닝 파이프라인에서 데이터 수집은 가장 먼저 수행되며 종종 가장 비용이 많이 드는 단계입니다. 모델을 학습시키려면 사진, 문자 메시지, 건강 기록, 금융 거래 등 사용자의 원래 데이터를 수집하여 중앙 클라우드 서버에 업로드해야 합니다.

이러한 중앙 집중식 접근 방식은 AI 혁명을 이끌어 왔지만 다음과 같은 중대한 과제에 직면해 있습니다.

  1. 개인정보 보호 우려: 사용자는 개인 데이터를 타사 서버에 업로드하는 것을 점점 더 꺼려하고 있습니다.
  2. 데이터 규제: GDPR 및 HIPAA와 같은 규정은 개인 데이터의 전송 및 저장 방식을 엄격하게 제한합니다.
  3. 대역폭 비용: 수백만 대의 에ッジ 디바이스(스마트폰 등)에서 수 기가바이트의 원시 데이터를 업로드하는 것은 매우 비효율적입니다.

**연합 학습(Federated Learning - FL)**은 기존의 패러다임을 완전히 뒤집어 이 문제를 해결합니다. 데이터를 모델이 있는 곳으로 가져오는 대신 모델을 데이터가 있는 곳으로 가져갑니다.


핵심 개념: 탈중앙화 학습

연합 학습에서 중앙 서버는 글로벌 모델을 관리합니다. 이 모델을 학습시키기 위해 원시 데이터를 수집하는 대신, 서버는 스마트폰, 스마트 홈 기기 또는 지역 병원 데이터베이스와 같은 에ッジ 디바이스(클라이언트) 네트워크 전체에서 협력적인 학습 프로세스를 조율합니다.

연합 학습의 가장 기본적인 규칙은 다음과 같습니다.

원시 데이터는 로컬 디바이스를 절대 떠나지 않습니다. 오직 수학적인 모델 업데이트 정보만 공유됩니다.


단계별 가이드: 작동 방식

전형적인 연합 학습의 학습 주기(통신 라운드라고 함)는 다음과 같이 5가지 주요 단계로 구성됩니다.

sequenceDiagram
    participant Server as 중앙 서버 (글로벌 모델)
    participant ClientA as 클라이언트 A (개인 데이터 A)
    participant ClientB as 클라이언트 B (개인 데이터 B)
    
    rect rgb(240, 248, 255)
        Note over Server: 단계 1: 글로벌 모델 초기화
    end
    Server->>ClientA: 단계 2: 글로벌 가중치(W_t) 전송
    Server->>ClientB: 단계 2: 글로벌 가중치(W_t) 전송
    rect rgb(245, 245, 245)
        Note over ClientA: 단계 3: 개인 데이터로 로컬 학습
        Note over ClientB: 단계 3: 개인 데이터로 로컬 학습
    end
    ClientA->>Server: 단계 4: 로컬 업데이트(W_t^A) 전송
    ClientB->>Server: 단계 4: 로컬 업데이트(W_t^B) 전송
    rect rgb(240, 255, 240)
        Note over Server: 단계 5: 업데이트 평균화(FedAvg)<br/>글로벌 모델 업데이트(W_t+1)
    end

1. 초기화

중앙 서버는 시작 가중치($W_0$)로 글로벌 모델을 초기화합니다. 이 가중치는 무작위이거나 공개 데이터셋에서 사전 학습된 것일 수 있습니다.

2. 배포(모델 브로드캐스트)

서버는 사용 가능한 클라이언트 디바이스 중 일부(예: 충전 중이고 Wi-Fi에 연결되어 있으며 유휴 상태인 스마트폰)를 선택하여 현재 글로벌 모델 가중치($W_t$)를 전송합니다.

3. 로컬 학습

선택된 각 클라이언트는 수신한 글로벌 모델을 자체 로컬 개인 데이터셋에서 학습시킵니다. 이는 확률적 경사하강법(SGD)과 같은 표준 최적화 알고리즘을 사용하여 수행됩니다. 몇 차례의 에포크(epochs)를 거친 후, 각 클라이언트 $i$는 새로운 로컬 모델 가중치($W_t^i$)를 생성합니다.

4. 로컬 업데이트 업로드

개인 학습 데이터를 보내는 대신, 클라이언트는 새로운 로컬 모델 가중치(또는 차이 $\Delta W_t^i = W_t^i - W_t$)만 중앙 서버로 전송합니다. 이 업데이트는 일반적으로 암호화 프로토콜을 사용하여 안전하게 전송됩니다.

5. 글로벌 집계

중앙 서버는 참여하는 모든 클라이언트로부터 업데이트를 수집합니다. 이를 평균화(일반적으로 각 클라이언트가 가진 로컬 데이터의 양에 비례하여 가중치 부여)하여 새로운 글로벌 모델($W_{t+1}$)을 생성합니다. 가장 많이 쓰이는 알고리즘은 **Federated Averaging(FedAvg)**입니다.

$$W_{t+1} = \sum_{i=1}^{K} \frac{n_i}{N} W_t^i$$

여기서:

  • $K$는 참여 클라이언트 수입니다.
  • $n_i$는 클라이언트 $i$의 데이터 샘플 수입니다.
  • $N$은 모든 참여 클라이언트의 총 데이터 샘플 수입니다 ($N = \sum n_i$).

이 주기는 글로벌 모델이 원하는 정확도에 도달할 때까지 여러 라운드에 걸쳐 반복됩니다.


코드 예제: 간단한 Python 시뮬레이션

연합 학습이 실제로 어떻게 작동하는지 확인하기 위해 NumPy를 사용하여 간단한 Python 시뮬레이션을 작성해 보겠습니다.

이 시나리오에서는 선형 회귀를 사용하여 주택 가격($y = w \cdot x$)을 예측하는 모델을 학습하고자 합니다. 중앙 서버와 3개의 클라이언트가 있으며, 각 클라이언트는 자체의 개인 주택 크기($x$)와 가격($y$) 데이터를 가지고 있습니다.

import numpy as np

# 1. 클라이언트 개인 데이터 설정 (서버와 공유 불가능)
# 각 클라이언트는 서로 다른 수의 로컬 샘플(n_i)을 가집니다.
clients_data = {
    "Client_1": {"x": np.array([1.0, 1.5, 2.0]), "y": np.array([110.0, 160.0, 210.0])}, # 실제 관계: y = 100x + 10
    "Client_2": {"x": np.array([0.8, 1.2]),       "y": np.array([90.0, 130.0])},       # 실제 관계: y = 100x + 10
    "Client_3": {"x": np.array([2.5, 3.0, 3.5]), "y": np.array([260.0, 310.0, 360.0])}  # 실제 관계: y = 100x + 10
}

# 모든 클라이언트의 총 데이터 포인트 수 (N)
total_samples = sum(len(data["x"]) for data in clients_data.values())

# 2. 서버 초기 가중치 (글로벌 모델: W_t)
global_weight = 10.0  # 초기 추측값 (실제 값 100.0과 매우 다름)
learning_rate = 0.05
epochs = 5  # 라운드당 로컬 학습 에포크 수
communication_rounds = 3

print(f"초기 글로벌 가중치: {global_weight:.2f}\n")

# 연합 학습 루프
for round_idx in range(communication_rounds):
    print(f"--- 통신 라운드 {round_idx + 1} ---")
    local_weights = []
    client_sample_sizes = []
    
    # 단계 2 & 3: 모델 배포 및 클라이언트 기기에서의 로컬 학습
    for client_name, data in clients_data.items():
        x = data["x"]
        y = data["y"]
        n_i = len(x)
        
        # 클라이언트는 글로벌 가중치를 수신
        w_local = global_weight
        
        # 클라이언트는 로컬에서 몇 에포크 동안 학습 진행
        for epoch in range(epochs):
            # 예측값 계산: y_pred = w * x
            y_pred = w_local * x
            # 단순 선형 회귀에 대한 경사(gradient) 계산
            gradient = -2 * np.mean(x * (y - y_pred))
            # 로컬 가중치 업데이트
            w_local -= learning_rate * gradient
            
        print(f"  {client_name}이(가) 로컬 가중치를 학습함: {w_local:.2f} (샘플 수: {n_i})")
        
        # 집계를 위해 로컬 가중치 및 샘플 크기 저장
        local_weights.append(w_local)
        client_sample_sizes.append(n_i)
        
    # 단계 4 & 5: Federated Averaging (FedAvg)을 사용한 서버 측 집계
    weighted_sum = 0.0
    for w, n in zip(local_weights, client_sample_sizes):
        weighted_sum += w * n
        
    global_weight = weighted_sum / total_samples
    print(f"=> 서버에서 집계된 글로벌 가중치: {global_weight:.2f}\n")

print(f"연합 학습 후 최종 글로벌 모델 가중치: {global_weight:.2f}")

이 코드가 연합 학습을 나타내는 이유:

  • clients_data 딕셔너리는 격리된 데이터베이스를 나타냅니다. 서버는 이에 직접 접근할 수 없습니다.
  • 학습 루프에서 클라이언트로부터 서버로 전송되는 유일한 변수는 w_local입니다.
  • 서버는 샘플 크기(client_sample_sizes)를 기반으로 가중 평균을 수행하며, 이는 FedAvg 수학 공식을 따릅니다.

중앙 집중식 학습 vs. 연합 학습

특징 중앙 집중식 머신러닝 연합 학습
데이터 위치 중앙 클라우드/서버 분산 에지 디바이스
개인정보 보호 원시 데이터를 클라우드에 업로드 데이터가 기기 내에 유지됨
대역폭 높음 (원시 데이터셋 업로드) 낮음 (모델 가중치 업로드)
데이터 다양성 업로드된 데이터셋으로 한정됨 매우 높음 (실제 다양한 환경의 에지 데이터)
규정 준수 어려움 (GDPR/HIPAA 등 장애물 존재) 자체 규정 준수 가능 (설계 자체 반영)

보안 애드온: 개인정보 보호 및 암호화

연합 학습은 기본적으로 중앙 집중식 학습보다 안전하지만, 가중치 정보를 서버에 전송하는 것 역시 가중치를 역공학하여 학습 데이터를 재구성할 수 있다는 미세한 보안 리스크를 안고 있습니다. 이를 방지하기 위해 연합 학습은 주로 두 가지 핵심 보안 기술과 결합됩니다.

  1. 보안 집계(Secure Aggregation - SecAgg): 개별 클라이언트의 업데이트 내용을 보지 않고도 서버가 모든 로컬 업데이트의 합계를 계산할 수 있도록 하는 암호화 프로토콜입니다. 서버는 오직 집계된 결과만 볼 수 있습니다.
  2. 차분 프라이버시(Differential Privacy - DP): 업로드하기 전에 로컬 가중치에 수학적 ‘노이즈’를 임의로 추가하는 방법입니다. 이를 통해 글로벌 모델이 특정 사용자의 데이터를 역추적하거나 식별할 수 없도록 보장합니다.

실제 활용 사례

연합 학습은 오늘날 여러분의 기기에서 이미 조용히 작동하고 있습니다.

  • 구글 Gboard: 구글은 연합 학습을 활용하여 다음 단어 예측 및 검색어 제안을 학습시킵니다. 사용자의 타이핑 습관이나 언어 습관을 구글 서버로 직접 보내지 않고도 키보드가 스마트하게 학습합니다.
  • 애플 QuickType: 애플은 iPhone 자체에서 직접 오타 자동 수정 및 Siri 음성 인식 제안을 학습시키기 위해 분산 학습을 이용합니다.
  • 헬스케어 (MELLODDY 프로젝트): 주요 제약 회사들은 자사의 독점 연구 자료를 경쟁사에 노출시키지 않고도, 안전한 비공개 화학 데이터베이스에서 연합 학습을 사용하여 신약 발견 모델을 공동으로 학습시키고 있습니다.

요약

연합 학습은 AI 시스템 구축 방식의 패러다임 변화를 의미합니다. 데이터 소유권을 존중하고, 통신 비용을 최소화하며, 강력한 규제를 받는 산업에서도 AI 학습을 가능하게 합니다. 학습 과정을 기기의 에지로 가져감으로써, 우리는 개인 데이터를 외부로 내보내지 않고 원래 있는 곳에 안전하게 유지하면서도 더욱 스마트하고 개인화된 모델을 구축할 수 있습니다.


Ghaznix 블로그에서 탈중앙화 기술에 대한 더 많은 통찰을 얻어보세요 →