xData Aggregation FW: API 一覧
当ドキュメントでは、xData Aggregation FW について記述しています。
class Manager
Aggregator と Party の構成を管理するクラスです。
コード例:
次のように宣言します。
class TutorialManager(Manager):
AGGREGATOR = TutorialAggregator
TRAINER = TutorialTrainer
クラス変数
- AGGREGATOR:
Aggregator
クラスを指定します - TRAINER:
Trainer
クラスを指定します
__init__
引数:
- agg: RemoteStore | None: Aggregator が使用するストアを指定します
- parties: list[RemoteStore]: Party 群が使用するストアを指定します
federate
次の流れで一連の処理を実行します。
- N台の Party で
federate
を実行する - Aggregator で
federate
を実行する
引数:
- current_round: int: これから実行する
round
を渡します。
戻り値:
- result: 未定義
コード例:
manager = TutorialManager(agg=store, parties=[store, store])
current_round = 1
manager.federate(current_round=current_round)
上記のコード例は次のコードとおおよそ等価です。
manager = TutorialManager(agg=store, parties=[store, store])
current_round = 1
aggregator = manager.get_aggregator()
for i in range(len(manager.parties)):
trainer = manager.get_trainer(i)
trainer.federate(current_round=current_round)
aggregator.federate(current_round=current_round)
get_aggregator
Aggregator
インスタンスを取得します。
戻り値:
- result: Aggregator
get_trainer
Trainer
インスタンスを取得します。
引数:
- index: int: 初期化時に指定したストアの添字
戻り値:
- result: Party
class Aggregator
Aggregator として、連合学習の一連の処理を実行するクラスです。
コード例:
# 集約アルゴリズムを実装します。
class TutorialAggregator(Aggregator):
TRAINER = TutorialTrainer
SERIALIZER = TutorialSerializer
def load_feedback_list(self):
model_ids = self.agg.get_latest_feedbacked_models()
if len(model_ids) == 0:
raise Exception()
for model_id in model_ids:
model = self.agg.load_model(model_id)
meta = self.agg.get_model_meta(model_id)
yield model, meta
def aggregate(self, current_round: int, feedback_list):
aggregated_model = 0
aggregated_meta = {"data_size": 0}
for model, meta in feedback_list:
aggregated_model += model
aggregated_meta["data_size"] += meta["data_size"]
return aggregated_model, aggregated_meta
federate
次の流れで一連の処理を実行します。
load_feedback_list
: モデル群とそのメタデータを取得しますget_latest_feedbacked_models
: 前回の集約以降にフィードバックされたモデル群を取得します。load_feedback_list
内で呼ばれます。deserialize
: 与えられたモデルをメモリ上に復元します(load_feedback_list
内のload_model
呼び出し時に呼ばれます)aggregate
:load_feedback_list
の結果を受け取り、集約処理を実行しますserialize
:aggregate
で返されたモデルを直列化しますput_aggregated_model
:serialize
で直列化したモデルとメタデータを自身の集約済みのモデルとしてモデルストアに格納します
引数:
- current_round: int: これから実行する
round
を渡します。
戻り値:
- result: dict
put_aggregated_model
で返されたmodel_info
を返します
load_feedback_list
(抽象メソッド)
フィードバックされたモデル群を取得する処理を記述してください。 デフォルトの実装が用意されています。
戻り値:
- result: Iterable 任意の
Iterable
を返してください。デフォルトは、モデル本体とモデルのメタデータのTuple
を列挙します。
aggregate
(抽象メソッド)
集約処理を記述してください。
引数:
- parameters:
current_round > 1
の時、集約済みモデルが渡されます。そうでなければNone
です。 - current_round: int: 現在の
round
が渡されます。 - meta:
current_round > 1
の時、集約済みモデルに付与されたメタデータが渡されます。そうでなければ{}
です。
戻り値:
- result: Tuple[Any, dict] モデルと
dict
を返してください。
class Trainer
Party として、連合学習の一連の処理を実行するクラスです。
コード例:
class TutorialTrainer(Trainer):
def train(self, parameters, current_round, meta):
if parameters is None:
parameters = 0
parameters += 1
return parameters, {"data_size": 2}
federate
次の流れで一連の処理を実行します。
request_transfer
:current_round > 1
の時、アップストリームに最新の集約済みモデルを要求し、転送されたモデルとして自身のモデルストアを格納しますdeserialize
:current_round > 1
の時、与えられたモデルをメモリ上に復元しますtrain
: 実装した学習処理を実行します。current_round > 1
の時、request_transfer
で取得したモデルが渡されますserialize
:train
で返されたモデルを直列化しますput_model
:serialize
で直列化したモデルとメタデータを自身のストアに登録しますfeedback
:put_model
で保存したモデルをフィードバックされたモデルとしてアップストリームのモデルストアに格納します
引数:
- current_round: int: これから実行する
round
を渡します。
戻り値:
- result: dict
feedback
でアップストリームから返されたmodel_info
を返します
train
(抽象メソッド)
学習処理を記述してください。
引数:
- parameters:
current_round > 1
の時、集約済みモデルが渡されます。そうでなければNone
です。 - current_round: int: 現在の
round
が渡されます。 - meta:
current_round > 1
の時、集約済みモデルに付与されたメタデータが渡されます。そうでなければ{}
です。
戻り値:
- result: Tuple[Any, dict] モデルと
dict
を返してください。
class Serializer
モデル送受信時の永続化方法・復元方法を記述するクラスです。
コード例:
import torch
class PytorchSerializer:
def serialize(self, model, output_path):
torch.save(model, output_path)
return output_path
def deserialize(self, input_path, model_meta={}):
with open(input_path, 'rb') as f:
return torch.load(f.read(), map_location=torch.device('cpu'))
serialize
(抽象メソッド)
モデルの永続化方法を記述します。
引数:
- model: モデル
- output_path: int: フレームワークから保存場所を指定される保存先のパス
戻り値:
- result: str 保存先のパスを返してください
deserialize
(抽象メソッド)
永続化されたモデルの復元方法を記述します。
引数:
- model: モデル
- output_path: int: フレームワークから保存場所を指定される、保存先のパス
戻り値:
- result モデルを返してください。
class RemoteStore
Federated Learning Web API をバックエンドとするモデルストアです。
Federated Learning Web API のラッパー(model_store_ddc
の指定を初期化時のみ行う)のように機能します。
詳しくは、 Federated Learning API一覧 を参考にしてください。
コード例:
from xdata_fl.client import Api
from xdata_agg.fw._abc import RemoteStore
class MyStore(RemoteStore):
SERIALIZER = PytorchSerializer
model_api = Api(endpoint="", apikey="", apisecret="")
model_store_ddc = "ddc:my_first_model_store"
# RemoteStoreを指定すると、model_store_ddc の指定を省略できます
with MyStore(model_api, model_store_ddc=model_store_ddc) as store:
model_info = store.put_model(model)
model = store.load_model(model_info["model_id"])
クラス変数
- SERIALIZER:
Serializer
クラスを指定します(load_model
やput_model
時に呼ばれます)
__init__
引数:
- api:
xdata_fl.client.Api
: バックエンドを指定します - model_store_ddc: str: バックエンドのモデルストアを指定します