xData Aggregation FW
当ドキュメントでは、xData Aggregation FW について記述しています。
class Manager
Aggregator と Party の構成を管理するクラスです。
コード例:
次のように宣言します。
class TutorialManager(Manager):
AGGREGATOR = TutorialAggregator
TRAINER = TutorialTrainer
クラス変数
- AGGREGATOR:
Aggregatorクラスを指定します - TRAINER:
Trainerクラスを指定します
__init__
引数:
- agg: RemoteStore | None: Aggregator が使用するストアを指定します
- parties: list[RemoteStore]: Party 群が使用するストアを指定します
federate
次の流れで一連の処理を実行します。
- N台の Party で
federateを実行する - Aggregator で
federateを実行する
引数:
- current_round: int: これから実行する
roundを渡します。
戻り値:
- result: 未定義
コード例:
manager = TutorialManager(agg=store, parties=[store, store])
current_round = 1
manager.federate(current_round=current_round)
上記のコード例は次のコードとおおよそ等価です。
manager = TutorialManager(agg=store, parties=[store, store])
current_round = 1
aggregator = manager.get_aggregator()
for i in range(len(manager.parties)):
trainer = manager.get_trainer(i)
trainer.federate(current_round=current_round)
aggregator.federate(current_round=current_round)
get_aggregator
Aggregator インスタンスを取得します。
戻り値:
- result: Aggregator
get_trainer
Trainer インスタンスを取得します。
引数:
- index: int: 初期化時に指定したストアの添字
戻り値:
- result: Party
class Aggregator
Aggregator として、連合学習の一連の処理を実行するクラスです。
コード例:
# 集約アルゴリズムを実装します。
class TutorialAggregator(Aggregator):
TRAINER = TutorialTrainer
SERIALIZER = TutorialSerializer
def load_feedback_list(self):
model_ids = self.agg.get_latest_feedbacked_models()
if len(model_ids) == 0:
raise Exception()
for model_id in model_ids:
model = self.agg.load_model(model_id)
meta = self.agg.get_model_meta(model_id)
yield model, meta
def aggregate(self, current_round: int, feedback_list):
aggregated_model = 0
aggregated_meta = {"data_size": 0}
for model, meta in feedback_list:
aggregated_model += model
aggregated_meta["data_size"] += meta["data_size"]
return aggregated_model, aggregated_meta
federate
次の流れで一連の処理を実行します。
load_feedback_list: モデル群とそのメタデータを取得しますget_latest_feedbacked_models: 前回の集約以降にフィードバックされたモデル群を取得します。load_feedback_list内で呼ばれます。deserialize: 与えられたモデルをメモリ上に復元します(load_feedback_list内のload_model呼び出し時に呼ばれます)aggregate:load_feedback_listの結果を受け取り、集約処理を実行しますserialize:aggregateで返されたモデルを直列化しますput_aggregated_model:serializeで直列化したモデルとメタデータを自身の集約済みのモデルとしてモデルストアに格納します
引数:
- current_round: int: これから実行する
roundを渡します。
戻り値:
- result: dict
put_aggregated_modelで返されたmodel_infoを返します
load_feedback_list(抽象メソッド)
フィードバックされたモデル群を取得する処理を記述してください。 デフォルトの実装が用意されています。
戻り値:
- result: Iterable 任意の
Iterableを返してください。デフォルトは、モデル本体とモデルのメタデータのTupleを列挙します。
aggregate(抽象メソッド)
集約処理を記述してください。
引数:
- parameters:
current_round > 1の時、集約済みモデルが渡されます。そうでなければNoneです。 - current_round: int: 現在の
roundが渡されます。 - meta:
current_round > 1の時、集約済みモデルに付与されたメタデータが渡されます。そうでなければ{}です。
戻り値:
- result: Tuple[Any, dict] モデルと
dictを返してください。
class Trainer
Party として、連合学習の一連の処理を実行するクラスです。
コード例:
class TutorialTrainer(Trainer):
def train(self, parameters, current_round, meta):
if parameters is None:
parameters = 0
parameters += 1
return parameters, {"data_size": 2}
federate
次の流れで一連の処理を実行します。
request_transfer:current_round > 1の時、アップストリームに最新の集約済みモデルを要求し、転送されたモデルとして自身のモデルストアを格納しますdeserialize:current_round > 1の時、与えられたモデルをメモリ上に復元しますtrain: 実装した学習処理を実行します。current_round > 1の時、request_transferで取得したモデルが渡されますserialize:trainで返されたモデルを直列化しますput_model:serializeで直列化したモデルとメタデータを自身のストアに登録しますfeedback:put_modelで保存したモデルをフィードバックされたモデルとしてアップストリームのモデルストアに格納します
引数:
- current_round: int: これから実行する
roundを渡します。
戻り値:
- result: dict
feedbackでアップストリームから返されたmodel_infoを返します
train(抽象メソッド)
学習処理を記述してください。
引数:
- parameters:
current_round > 1の時、集約済みモデルが渡されます。そうでなければNoneです。 - current_round: int: 現在の
roundが渡されます。 - meta:
current_round > 1の時、集約済みモデルに付与されたメタデータが渡されます。そうでなければ{}です。
戻り値:
- result: Tuple[Any, dict] モデルと
dictを返してください。
class Serializer
モデル送受信時の永続化方法・復元方法を記述するクラスです。
コード例:
import torch
class PytorchSerializer:
def serialize(self, model, output_path):
torch.save(model, output_path)
return output_path
def deserialize(self, input_path, model_meta={}):
with open(input_path, 'rb') as f:
return torch.load(f.read(), map_location=torch.device('cpu'))
serialize(抽象メソッド)
モデルの永続化方法を記述します。
引数:
- model: モデル
- output_path: int: フレームワークから保存場所を指定される保存先のパス
戻り値:
- result: str 保存先のパスを返してください
deserialize(抽象メソッド)
永続化されたモデルの復元方法を記述します。
引数:
- model: モデル
- output_path: int: フレームワークから保存場所を指定される、保存先のパス
戻り値:
- result モデルを返してください。
class RemoteStore
Federated Learning Web API をバックエンドとするモデルストアです。
Federated Learning Web API のラッパー(model_store_ddcの指定を初期化時のみ行う)のように機能します。
詳しくは、 Federated Learning API一覧 を参考にしてください。
コード例:
from xdata_fl.client import Api
from xdata_agg.fw._abc import RemoteStore
class MyStore(RemoteStore):
SERIALIZER = PytorchSerializer
model_api = Api(endpoint="", apikey="", apisecret="")
model_store_ddc = "ddc:my_first_model_store"
# RemoteStoreを指定すると、model_store_ddc の指定を省略できます
with MyStore(model_api, model_store_ddc=model_store_ddc) as store:
model_info = store.put_model(model)
model = store.load_model(model_info["model_id"])
クラス変数
- SERIALIZER:
Serializerクラスを指定します(load_modelやput_model時に呼ばれます)
__init__
引数:
- api:
xdata_fl.client.Api: バックエンドを指定します - model_store_ddc: str: バックエンドのモデルストアを指定します