Skip to content

xData Aggregation FW: API 一覧

当ドキュメントでは、xData Aggregation FW について記述しています。

class Manager

Aggregator と Party の構成を管理するクラスです。

コード例:

次のように宣言します。

class TutorialManager(Manager):
    AGGREGATOR = TutorialAggregator
    TRAINER = TutorialTrainer

クラス変数

  • AGGREGATOR: Aggregator クラスを指定します
  • TRAINER: Trainer クラスを指定します

__init__

引数:

  • agg: RemoteStore | None: Aggregator が使用するストアを指定します
  • parties: list[RemoteStore]: Party 群が使用するストアを指定します

federate

次の流れで一連の処理を実行します。

  1. N台の Party で federate を実行する
  2. Aggregator で federate を実行する

引数:

  • current_round: int: これから実行するroundを渡します。

戻り値:

  • result: 未定義

コード例:

manager = TutorialManager(agg=store, parties=[store, store])
current_round = 1

manager.federate(current_round=current_round)

上記のコード例は次のコードとおおよそ等価です。

manager = TutorialManager(agg=store, parties=[store, store])
current_round = 1

aggregator = manager.get_aggregator()

for i in range(len(manager.parties)):
  trainer = manager.get_trainer(i)
  trainer.federate(current_round=current_round)

aggregator.federate(current_round=current_round)

get_aggregator

Aggregator インスタンスを取得します。

戻り値:

  • result: Aggregator

get_trainer

Trainer インスタンスを取得します。

引数:

  • index: int: 初期化時に指定したストアの添字

戻り値:

  • result: Party

class Aggregator

Aggregator として、連合学習の一連の処理を実行するクラスです。

コード例:

# 集約アルゴリズムを実装します。
class TutorialAggregator(Aggregator):
    TRAINER = TutorialTrainer
    SERIALIZER = TutorialSerializer

    def load_feedback_list(self):
        model_ids = self.agg.get_latest_feedbacked_models()

        if len(model_ids) == 0:
            raise Exception()

        for model_id in model_ids:
            model = self.agg.load_model(model_id)
            meta = self.agg.get_model_meta(model_id)
            yield model, meta

    def aggregate(self, current_round: int, feedback_list):
        aggregated_model = 0
        aggregated_meta = {"data_size": 0}

        for model, meta in feedback_list:
            aggregated_model += model
            aggregated_meta["data_size"] += meta["data_size"]

        return aggregated_model, aggregated_meta

federate

次の流れで一連の処理を実行します。

  1. load_feedback_list: モデル群とそのメタデータを取得します
  2. get_latest_feedbacked_models: 前回の集約以降にフィードバックされたモデル群を取得します。load_feedback_list内で呼ばれます。
  3. deserialize: 与えられたモデルをメモリ上に復元します(load_feedback_list内のload_model呼び出し時に呼ばれます)
  4. aggregate: load_feedback_listの結果を受け取り、集約処理を実行します
  5. serialize: aggregate で返されたモデルを直列化します
  6. put_aggregated_model: serializeで直列化したモデルとメタデータを自身の集約済みのモデルとしてモデルストアに格納します

引数:

  • current_round: int: これから実行するroundを渡します。

戻り値:

  • result: dict put_aggregated_modelで返されたmodel_infoを返します

load_feedback_list(抽象メソッド)

フィードバックされたモデル群を取得する処理を記述してください。 デフォルトの実装が用意されています。

戻り値:

  • result: Iterable 任意のIterableを返してください。デフォルトは、モデル本体とモデルのメタデータのTupleを列挙します。

aggregate(抽象メソッド)

集約処理を記述してください。

引数:

  • parameters: current_round > 1の時、集約済みモデルが渡されます。そうでなければNoneです。
  • current_round: int: 現在のroundが渡されます。
  • meta: current_round > 1の時、集約済みモデルに付与されたメタデータが渡されます。そうでなければ{}です。

戻り値:

  • result: Tuple[Any, dict] モデルとdictを返してください。

class Trainer

Party として、連合学習の一連の処理を実行するクラスです。

コード例:

class TutorialTrainer(Trainer):
    def train(self, parameters, current_round, meta):
        if parameters is None:
            parameters = 0

        parameters += 1
        return parameters, {"data_size": 2}

federate

次の流れで一連の処理を実行します。

  1. request_transfer: current_round > 1 の時、アップストリームに最新の集約済みモデルを要求し、転送されたモデルとして自身のモデルストアを格納します
  2. deserialize: current_round > 1 の時、与えられたモデルをメモリ上に復元します
  3. train: 実装した学習処理を実行します。current_round > 1 の時、request_transferで取得したモデルが渡されます
  4. serialize: train で返されたモデルを直列化します
  5. put_model: serializeで直列化したモデルとメタデータを自身のストアに登録します
  6. feedback: put_modelで保存したモデルをフィードバックされたモデルとしてアップストリームのモデルストアに格納します

引数:

  • current_round: int: これから実行するroundを渡します。

戻り値:

  • result: dict feedbackでアップストリームから返されたmodel_infoを返します

train(抽象メソッド)

学習処理を記述してください。

引数:

  • parameters: current_round > 1の時、集約済みモデルが渡されます。そうでなければNoneです。
  • current_round: int: 現在のroundが渡されます。
  • meta: current_round > 1の時、集約済みモデルに付与されたメタデータが渡されます。そうでなければ{}です。

戻り値:

  • result: Tuple[Any, dict] モデルとdictを返してください。

class Serializer

モデル送受信時の永続化方法・復元方法を記述するクラスです。

コード例:

import torch

class PytorchSerializer:
    def serialize(self, model, output_path):
        torch.save(model, output_path)
        return output_path

    def deserialize(self, input_path, model_meta={}):
        with open(input_path, 'rb') as f:
            return torch.load(f.read(), map_location=torch.device('cpu'))

serialize(抽象メソッド)

モデルの永続化方法を記述します。

引数:

  • model: モデル
  • output_path: int: フレームワークから保存場所を指定される保存先のパス

戻り値:

  • result: str 保存先のパスを返してください

deserialize(抽象メソッド)

永続化されたモデルの復元方法を記述します。

引数:

  • model: モデル
  • output_path: int: フレームワークから保存場所を指定される、保存先のパス

戻り値:

  • result モデルを返してください。

class RemoteStore

Federated Learning Web API をバックエンドとするモデルストアです。 Federated Learning Web API のラッパー(model_store_ddcの指定を初期化時のみ行う)のように機能します。

詳しくは、 Federated Learning API一覧 を参考にしてください。

コード例:

from xdata_fl.client import Api
from xdata_agg.fw._abc import RemoteStore

class MyStore(RemoteStore):
    SERIALIZER = PytorchSerializer

model_api = Api(endpoint="", apikey="", apisecret="")
model_store_ddc = "ddc:my_first_model_store"

# RemoteStoreを指定すると、model_store_ddc の指定を省略できます
with MyStore(model_api, model_store_ddc=model_store_ddc) as store:
    model_info = store.put_model(model)
    model = store.load_model(model_info["model_id"])

クラス変数

  • SERIALIZER: Serializer クラスを指定します(load_modelput_model時に呼ばれます)

__init__

引数:

  • api: xdata_fl.client.Api: バックエンドを指定します
  • model_store_ddc: str: バックエンドのモデルストアを指定します