【OSCP】Add a torch-backend fed_pac strategy to SecretFlow #1276
base: main
Conversation
add fedpac server
CLA Assistant Lite bot: All contributors have signed the CLA ✍️ ✅
I have read the CLA Document and I hereby sign the CLA
Add a unit test; see tests/ml/nn/fl/test_fl_model_torch.py for reference.
from torch.utils.data import DataLoader


def cifar10(stage='train'):
Move the dataset-processing code into the unit test, or into a test script.
+1. Move this into the unit test using the custom databuilder pattern, or directly reuse the logic in secretflow/utils/simulation/datasets.py.
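For context, the "custom databuilder" pattern mentioned above means the test supplies a callable that builds each stage's data loader, so dataset handling lives in the test rather than inside the strategy. A hedged, torch-free sketch (all names here are illustrative; a real test would return a torch DataLoader over CIFAR-10):

```python
# Illustrative databuilder: the test binds the data and batch size, and the
# framework later calls databuilder(stage) to get an iterable of batches.
def make_databuilder(data, batch_size):
    def databuilder(stage="train"):
        # Stand-in for a real DataLoader: a plain list of batches.
        return [data[i : i + batch_size] for i in range(0, len(data), batch_size)]

    return databuilder
```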
@@ -49,6 +51,7 @@ def __init__(
        self.train_set = None
        self.eval_set = None
        self.skip_bn = skip_bn
        # self.dataset_size = torch.tensor(len(self.train_set.dataset)).to(self.exe_device)
Delete this unused code.
Done
@@ -214,11 +217,11 @@ def next_batch(self, stage="train"):
     def get_rows_count(self, filename):
         return int(rows_count(filename=filename)) - 1  # except header line

-    def get_weights(self, return_numpy=True):
+    def get_weights(self):
Why was the return_numpy parameter removed here?
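One way to address this comment is to keep the parameter and forward it, so callers can still choose the return type. A minimal hypothetical holder class illustrating the pattern (WeightsHolder is not part of the codebase; in practice the weights would be torch tensors):

```python
import numpy as np


class WeightsHolder:
    """Hypothetical sketch: keep return_numpy in the signature and forward it."""

    def __init__(self, weights):
        self._weights = weights  # list of array-likes (tensors in practice)

    def get_weights(self, return_numpy=True):
        if return_numpy:
            # Convert to numpy arrays for transport/aggregation.
            return [np.asarray(w) for w in self._weights]
        # Otherwise return the native objects untouched.
        return list(self._weights)
```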
        self.num_classes = kwargs.get("num_classes", 10)
        self.criterion = nn.CrossEntropyLoss()
        self.local_model = self.model
        self.last_model = deepcopy(self.model)
What is self.last_model used for?
This was copied verbatim from the reference implementation, which saves a deep copy of the model before local training starts, presumably with rollback in mind. However, the reference code never actually uses self.last_model either, so this line has now been removed.
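For reference, the rollback pattern the upstream code hints at looks roughly like the sketch below. This is purely illustrative (LocalTrainer and the divergence check are hypothetical, not from either codebase): snapshot the model before local training, and restore the snapshot if training fails.

```python
from copy import deepcopy


class LocalTrainer:
    """Hypothetical sketch of snapshot-and-rollback around a local training step."""

    def __init__(self, model):
        self.model = model

    def local_train(self, update):
        last_model = deepcopy(self.model)  # snapshot before local training
        trained = update(self.model)       # `update` stands in for the training step
        if trained is None:                # e.g. training diverged or was aborted
            self.model = last_model        # roll back to the snapshot
        else:
            self.model = trained
        return self.model
```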
from copy import deepcopy


class FedPACTorchModel(BaseTorchModel):
This class can be merged into FedPAC in strategy/fed_pac.py.
Done.
        lr = kwargs.get("lr", 0.01)

        epoch_classifier = 1
        optimizer = torch.optim.SGD(
Could the optimizer, loss, and similar components be defined in TorchModel instead?
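One way to read this suggestion: instead of constructing SGD with hard-coded hyperparameters inside the strategy, the model definition binds a pre-configured optimizer factory and the strategy only invokes it. A runnable stand-in for the idea (the SGD class below is a hypothetical placeholder so the sketch runs without torch, not torch.optim.SGD; fit_classifier is likewise illustrative):

```python
from functools import partial


class SGD:
    """Placeholder mimicking an optimizer constructor's signature."""

    def __init__(self, params, lr=0.01, momentum=0.0, weight_decay=0.0):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay


# Hyperparameters are bound where the model is defined...
optim_fn = partial(SGD, lr=0.01, momentum=0.5, weight_decay=0.0005)


def fit_classifier(params, optim_fn):
    # ...and the strategy stays agnostic of the optimizer choice.
    optimizer = optim_fn(params)
    return optimizer
```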
            momentum=0.5,
            weight_decay=0.0005,
        )
        for ep in range(epoch_classifier):
Does each step train for a full epoch?
from secretflow.security.aggregation.aggregator import Aggregator


class FedPACAggregator(Aggregator):
If the average/sum logic is unchanged, this can inherit from PlainAggregator.
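The suggested inheritance can be sketched as follows. Note the PlainAggregator below is a simplified stand-in for the real class in secretflow.security.aggregation (which operates on device objects rather than raw arrays); it is shown only to illustrate the structure:

```python
import numpy as np


class PlainAggregator:
    """Simplified stand-in: plain (optionally weighted) average and sum."""

    def average(self, data, axis=0, weights=None):
        return np.average(np.stack(data), axis=axis, weights=weights)

    def sum(self, data, axis=0):
        return np.sum(np.stack(data), axis=axis)


class FedPACAggregator(PlainAggregator):
    # Only FedPAC-specific aggregation (e.g. of classifier statistics) would be
    # added here; average/sum are inherited unchanged instead of re-implemented.
    pass
```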
add fed_pac to init
@@ -216,9 +219,9 @@ def get_rows_count(self, filename):

     def get_weights(self, return_numpy=True):
         if self.skip_bn:
-            return self.model.get_weights_not_bn(return_numpy=return_numpy)
+            return self.model.get_weights_not_bn(return_numpy=True)
Why is return_numpy hard-coded here?
@@ -23,7 +23,6 @@

 from .mixins import ParametersMixin

Please set up your formatter configuration; this file should not be modified here.
    return dataset


def batch_sampler(
batch_sampler is already implemented; just call the existing one in secretflow/ml/nn/fl/backend/torch/sampler.py.
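For readers unfamiliar with the helper: a batch sampler of this kind just yields batches of indices. A minimal illustrative version (this is not SecretFlow's actual sampler.py implementation; the signature is an assumption made for the sketch):

```python
import random


def batch_sampler(indices, batch_size, shuffle=False, seed=None):
    """Yield lists of indices of size batch_size (last batch may be shorter)."""
    idx = list(indices)
    if shuffle:
        # Seeded shuffle keeps the sampling reproducible across parties.
        random.Random(seed).shuffle(idx)
    for i in range(0, len(idx), batch_size):
        yield idx[i : i + batch_size]
```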