Paper Notes (8) - "Personalized Federated Learning using Hypernetworks"
Honestly, it feels very much like work from the computer-science crowd, so treat this mostly as a look at the code and the underlying idea.
This is an ICML 2021 paper. I went through both the paper and the code, so below I walk through the main ideas with the code alongside.
Motivation
The paper frames the difficulty of PFL as providing each client with a personalized model at the lowest possible communication cost. Its listed contributions follow the same line: decoupling communication cost from model complexity, being able to serve models of different sizes to devices with different compute budgets, and achieving good empirical results.
The authors achieve this decoupling by training a hypernetwork on the server side that generates the required model parameters for each client.
Model Construction
The hypernetwork in the paper is a multi-headed network: each head outputs the weight tensor of one layer of the target network. Concretely, for CIFAR-10 the hypernetwork looks like this:
```python
from collections import OrderedDict

import torch.nn as nn
from torch.nn.utils import spectral_norm


class CNNHyper(nn.Module):
    def __init__(
            self, n_nodes, embedding_dim, in_channels=3, out_dim=10, n_kernels=16,
            hidden_dim=100, spec_norm=False, n_hidden=1):
        '''
        The hypernetwork stored on the server to generate the weights of the target network.

        Args:
            n_nodes: int, the total number of nodes (users or clients)
            embedding_dim: int, dimension of the embedding
            in_channels: int, the channels of the input image or data
            out_dim: int, the number of categories
            n_kernels: int, the number of kernels used in the CNN
            hidden_dim: int, the dimension of the final latent layer in the hypernetwork
            spec_norm: bool, whether to apply spectral norm
            n_hidden: int, the number of latent layers
        '''
        super().__init__()

        self.in_channels = in_channels
        self.out_dim = out_dim
        self.n_kernels = n_kernels
        self.embeddings = nn.Embedding(num_embeddings=n_nodes, embedding_dim=embedding_dim)

        # multilayer perceptron
        layers = [
            spectral_norm(nn.Linear(embedding_dim, hidden_dim)) if spec_norm else nn.Linear(embedding_dim, hidden_dim),
        ]
        for _ in range(n_hidden):
            layers.append(nn.ReLU(inplace=True))
            layers.append(
                spectral_norm(nn.Linear(hidden_dim, hidden_dim)) if spec_norm else nn.Linear(hidden_dim, hidden_dim),
            )
        self.mlp = nn.Sequential(*layers)

        # heads producing the weights of the target network
        self.c1_weights = nn.Linear(hidden_dim, self.n_kernels * self.in_channels * 5 * 5)
        self.c1_bias = nn.Linear(hidden_dim, self.n_kernels)
        self.c2_weights = nn.Linear(hidden_dim, 2 * self.n_kernels * self.n_kernels * 5 * 5)
        self.c2_bias = nn.Linear(hidden_dim, 2 * self.n_kernels)
        self.l1_weights = nn.Linear(hidden_dim, 120 * 2 * self.n_kernels * 5 * 5)
        self.l1_bias = nn.Linear(hidden_dim, 120)
        self.l2_weights = nn.Linear(hidden_dim, 84 * 120)
        self.l2_bias = nn.Linear(hidden_dim, 84)
        self.l3_weights = nn.Linear(hidden_dim, self.out_dim * 84)
        self.l3_bias = nn.Linear(hidden_dim, self.out_dim)

        if spec_norm:
            self.c1_weights = spectral_norm(self.c1_weights)
            self.c1_bias = spectral_norm(self.c1_bias)
            self.c2_weights = spectral_norm(self.c2_weights)
            self.c2_bias = spectral_norm(self.c2_bias)
            self.l1_weights = spectral_norm(self.l1_weights)
            self.l1_bias = spectral_norm(self.l1_bias)
            self.l2_weights = spectral_norm(self.l2_weights)
            self.l2_bias = spectral_norm(self.l2_bias)
            self.l3_weights = spectral_norm(self.l3_weights)
            self.l3_bias = spectral_norm(self.l3_bias)

    def forward(self, idx):
        emd = self.embeddings(idx)
        features = self.mlp(emd)

        weights = OrderedDict({
            "conv1.weight": self.c1_weights(features).view(self.n_kernels, self.in_channels, 5, 5),
            "conv1.bias": self.c1_bias(features).view(-1),
            "conv2.weight": self.c2_weights(features).view(2 * self.n_kernels, self.n_kernels, 5, 5),
            "conv2.bias": self.c2_bias(features).view(-1),
            "fc1.weight": self.l1_weights(features).view(120, 2 * self.n_kernels * 5 * 5),
            "fc1.bias": self.l1_bias(features).view(-1),
            "fc2.weight": self.l2_weights(features).view(84, 120),
            "fc2.bias": self.l2_bias(features).view(-1),
            "fc3.weight": self.l3_weights(features).view(self.out_dim, 84),
            "fc3.bias": self.l3_bias(features).view(-1),
        })
        return weights
```
The structure of the client-side Target Network:
```python
import torch.nn as nn
import torch.nn.functional as F


class CNNTarget(nn.Module):
    def __init__(self, in_channels=3, n_kernels=16, out_dim=10):
        super(CNNTarget, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, n_kernels, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(n_kernels, 2 * n_kernels, 5)
        self.fc1 = nn.Linear(2 * n_kernels * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_dim)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
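To make the relationship between the two modules concrete, here is a minimal usage sketch of my own (the numbers are made up and are not the repo's defaults): the dict returned by `CNNHyper.forward` uses exactly the parameter names in `CNNTarget`'s `state_dict`, so it can be loaded into the client model directly.

```python
import torch

# hypothetical toy setup: 10 clients, 13-dimensional client embeddings
hnet = CNNHyper(n_nodes=10, embedding_dim=13)
net = CNNTarget()

# pick client 3; the hypernetwork maps its embedding to a full set of weights
node_id = torch.tensor([3], dtype=torch.long)
weights = hnet(node_id)

# the keys match CNNTarget's state_dict, so the client model can load them directly
net.load_state_dict(weights)

x = torch.randn(2, 3, 32, 32)   # a fake batch of CIFAR-sized images
print(net(x).shape)             # -> torch.Size([2, 10])
```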
Optimization
Now for the optimization flow. The thing I was most curious about, the client descriptor \(\mathbf{v}_i\), is obtained simply by embedding the client's node_id, i.e. the client's index. The whole codebase instantiates only two models, one Hyper Network and one Target Network. In each round a single client is selected; the Target Network loads the weights computed from that client's node_id embedding and is then optimized locally.
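Before the code, the core trick is worth writing down. As I read the paper and the code below, the server never backpropagates through client data; instead, after the client runs a few local SGD steps, the total change in the target weights is used as a surrogate for the loss gradient and pushed through the hypernetwork with the chain rule:

\[
\nabla_{\varphi}\mathcal{L}_i
= \left(\nabla_{\varphi}\,\theta_i\right)^{\top} \nabla_{\theta_i}\mathcal{L}_i
\;\approx\; \left(\nabla_{\varphi}\,\theta_i\right)^{\top} \left(\theta_i - \tilde{\theta}_i\right),
\qquad \theta_i = h(\mathbf{v}_i;\varphi),
\]

where \(\theta_i\) are the weights the hypernetwork emitted at the start of the round and \(\tilde{\theta}_i\) are the weights after `inner_steps` of local training. In the code below this is exactly `delta_theta` fed as `grad_outputs` into `torch.autograd.grad`, which computes the vector-Jacobian product with respect to the hypernetwork parameters.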
```python
def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int,
          steps: int, inner_steps: int, optim: str, lr: float, inner_lr: float,
          embed_lr: float, wd: float, inner_wd: float, embed_dim: int, hyper_hid: int,
          n_hidden: int, n_kernels: int, bs: int, device, eval_every: int, save_path: Path,
          seed: int) -> None:
    '''
    The optimization process.

    Args:
        data_name: str, 'cifar10' or 'cifar100'
        data_path: the path of the data
        classes_per_node: int, the number of classes assigned to each node
        num_nodes: int, the total number of nodes or users
        steps: int, the number of communication rounds
        inner_steps: int, the number of local gradient steps
        optim: str, sgd or adam
        lr: float, learning rate of the server
        inner_lr: float, learning rate of the node
        embed_lr: float, learning rate of the embedding layer
        wd: float, weight decay of the server
        inner_wd: float, weight decay of the node
        embed_dim: int, the dimension of the embedding layer output
        hyper_hid: int, the dimension of the final hidden layer output
        n_hidden: int, the number of latent layers
        n_kernels: int, the number of kernels in the CNN
        bs: int, batch size
    '''
    ###############################
    # init nodes, hnet, local net #
    ###############################
    nodes = BaseNodes(data_name, data_path, num_nodes, classes_per_node=classes_per_node,
                      batch_size=bs)

    # setting the embedding dim according to n_nodes
    if embed_dim == -1:
        logging.info("auto embedding size")
        embed_dim = int(1 + num_nodes / 4)

    # build the models
    if data_name == "cifar10":
        hnet = CNNHyper(num_nodes, embed_dim, hidden_dim=hyper_hid, n_hidden=n_hidden, n_kernels=n_kernels)
        net = CNNTarget(n_kernels=n_kernels)
    elif data_name == "cifar100":
        hnet = CNNHyper(num_nodes, embed_dim, hidden_dim=hyper_hid,
                        n_hidden=n_hidden, n_kernels=n_kernels, out_dim=100)
        net = CNNTarget(n_kernels=n_kernels, out_dim=100)
    else:
        raise ValueError("choose data_name from ['cifar10', 'cifar100']")

    hnet = hnet.to(device)
    net = net.to(device)

    ##################
    # init optimizer #
    ##################
    embed_lr = embed_lr if embed_lr is not None else lr
    optimizers = {
        'sgd': torch.optim.SGD(
            [
                {'params': [p for n, p in hnet.named_parameters() if 'embed' not in n]},
                {'params': [p for n, p in hnet.named_parameters() if 'embed' in n], 'lr': embed_lr}
            ], lr=lr, momentum=0.9, weight_decay=wd
        ),
        'adam': torch.optim.Adam(params=hnet.parameters(), lr=lr)
    }
    optimizer = optimizers[optim]
    criteria = torch.nn.CrossEntropyLoss()

    ################
    # init metrics #
    ################
    last_eval = -1
    best_step = -1
    best_acc = -1
    test_best_based_on_step, test_best_min_based_on_step = -1, -1
    test_best_max_based_on_step, test_best_std_based_on_step = -1, -1
    step_iter = trange(steps)
    results = defaultdict(list)

    for step in step_iter:
        hnet.train()

        # select a client at random
        node_id = random.choice(range(num_nodes))

        # produce & load local network weights
        weights = hnet(torch.tensor([node_id], dtype=torch.long).to(device))
        net.load_state_dict(weights)

        # init inner optimizer
        inner_optim = torch.optim.SGD(
            net.parameters(), lr=inner_lr, momentum=.9, weight_decay=inner_wd
        )

        # storing theta_i for later calculating delta theta
        inner_state = OrderedDict({k: tensor.data for k, tensor in weights.items()})

        # NOTE: evaluation on sent model
        with torch.no_grad():
            net.eval()
            batch = next(iter(nodes.test_loaders[node_id]))
            img, label = tuple(t.to(device) for t in batch)
            pred = net(img)
            prvs_loss = criteria(pred, label)
            prvs_acc = pred.argmax(1).eq(label).sum().item() / len(label)
            net.train()

        # inner updates -> obtaining theta_tilda
        for i in range(inner_steps):
            net.train()
            inner_optim.zero_grad()
            optimizer.zero_grad()

            batch = next(iter(nodes.train_loaders[node_id]))
            img, label = tuple(t.to(device) for t in batch)

            pred = net(img)
            loss = criteria(pred, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 50)
            inner_optim.step()

        optimizer.zero_grad()

        final_state = net.state_dict()

        # calculating delta theta
        delta_theta = OrderedDict({k: inner_state[k] - final_state[k] for k in weights.keys()})

        # calculating phi gradient
        hnet_grads = torch.autograd.grad(
            list(weights.values()), hnet.parameters(), grad_outputs=list(delta_theta.values())
        )

        # update hnet weights
        for p, g in zip(hnet.parameters(), hnet_grads):
            p.grad = g

        torch.nn.utils.clip_grad_norm_(hnet.parameters(), 50)
        optimizer.step()

        step_iter.set_description(
            f"Step: {step+1}, Node ID: {node_id}, Loss: {prvs_loss:.4f}, Acc: {prvs_acc:.4f}"
        )

        # evaluation
        if step % eval_every == 0:
            last_eval = step
            step_results, avg_loss, avg_acc, all_acc = eval_model(nodes, num_nodes, hnet, net, criteria, device, split="test")
            logging.info(f"\nStep: {step+1}, AVG Loss: {avg_loss:.4f}, AVG Acc: {avg_acc:.4f}")

            results['test_avg_loss'].append(avg_loss)
            results['test_avg_acc'].append(avg_acc)

            _, val_avg_loss, val_avg_acc, _ = eval_model(nodes, num_nodes, hnet, net, criteria, device, split="val")
            if best_acc < val_avg_acc:
                best_acc = val_avg_acc
                best_step = step
                test_best_based_on_step = avg_acc
                test_best_min_based_on_step = np.min(all_acc)
                test_best_max_based_on_step = np.max(all_acc)
                test_best_std_based_on_step = np.std(all_acc)

            results['val_avg_loss'].append(val_avg_loss)
            results['val_avg_acc'].append(val_avg_acc)
            results['best_step'].append(best_step)
            results['best_val_acc'].append(best_acc)
            results['best_test_acc_based_on_val_beststep'].append(test_best_based_on_step)
            results['test_best_min_based_on_step'].append(test_best_min_based_on_step)
            results['test_best_max_based_on_step'].append(test_best_max_based_on_step)
            results['test_best_std_based_on_step'].append(test_best_std_based_on_step)

    # final evaluation if the last step was not evaluated inside the loop
    if step != last_eval:
        _, val_avg_loss, val_avg_acc, _ = eval_model(nodes, num_nodes, hnet, net, criteria, device, split="val")
        step_results, avg_loss, avg_acc, all_acc = eval_model(nodes, num_nodes, hnet, net, criteria, device, split="test")
        logging.info(f"\nStep: {step + 1}, AVG Loss: {avg_loss:.4f}, AVG Acc: {avg_acc:.4f}")

        results['test_avg_loss'].append(avg_loss)
        results['test_avg_acc'].append(avg_acc)

        if best_acc < val_avg_acc:
            best_acc = val_avg_acc
            best_step = step
            test_best_based_on_step = avg_acc
            test_best_min_based_on_step = np.min(all_acc)
            test_best_max_based_on_step = np.max(all_acc)
            test_best_std_based_on_step = np.std(all_acc)

        results['val_avg_loss'].append(val_avg_loss)
        results['val_avg_acc'].append(val_avg_acc)
        results['best_step'].append(best_step)
        results['best_val_acc'].append(best_acc)
        results['best_test_acc_based_on_val_beststep'].append(test_best_based_on_step)
        results['test_best_min_based_on_step'].append(test_best_min_based_on_step)
        results['test_best_max_based_on_step'].append(test_best_max_based_on_step)
        results['test_best_std_based_on_step'].append(test_best_std_based_on_step)

    save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)
    with open(str(save_path / f"results_{inner_steps}_inner_steps_seed_{seed}.json"), "w") as file:
        json.dump(results, file, indent=4)
```
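`BaseNodes` and `eval_model` come from the rest of the repository and are not shown in these notes. Judging only from how `eval_model` is called above, here is a rough sketch of what the per-client evaluation presumably does; the names and details are my guesses, not the repo's actual code: for every node, regenerate that node's personalized weights from the hypernetwork, load them into the target network, and average loss and accuracy over that node's loader.

```python
import numpy as np
import torch


@torch.no_grad()
def eval_model_sketch(nodes, num_nodes, hnet, net, criteria, device, split="test"):
    """Rough sketch of the evaluation loop; the real eval_model lives in the repo."""
    hnet.eval()
    net.eval()
    per_node, all_acc, all_loss = {}, [], []
    loaders = nodes.test_loaders if split == "test" else nodes.val_loaders  # assumed attribute names
    for node_id in range(num_nodes):
        # regenerate this client's personalized weights from its embedding
        weights = hnet(torch.tensor([node_id], dtype=torch.long).to(device))
        net.load_state_dict(weights)

        correct, total, loss_sum = 0, 0, 0.0
        for img, label in loaders[node_id]:
            img, label = img.to(device), label.to(device)
            pred = net(img)
            loss_sum += criteria(pred, label).item() * len(label)
            correct += pred.argmax(1).eq(label).sum().item()
            total += len(label)

        acc = correct / total
        per_node[node_id] = {"loss": loss_sum / total, "acc": acc}
        all_acc.append(acc)
        all_loss.append(loss_sum / total)

    return per_node, float(np.mean(all_loss)), float(np.mean(all_acc)), all_acc
```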
Summary
- The way the client descriptor \(\mathbf{v}_i\) is obtained is the part I find strangest. Producing it with an embedding is straightforward enough, but having the server generate it purely from the node_id gives the setup an adversarial feel between server and node, even though clients hold their own descriptive data such as demographic features. Letting each client derive its own \(\mathbf{v}_i\) from that data would seem more logical.
- The claimed decoupling of communication cost and model complexity feels ambiguous. What gets transmitted per round is the same as in plain FedAvg; the server can indeed host a very deep hypernetwork, but if the client's local model becomes more complex, the communication cost rises with it (the quick check after this list makes that concrete).
- The released code does not show how models of different sizes are generated for devices with different compute budgets.
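A quick back-of-the-envelope check of that second point, my own computation rather than anything from the paper (the `n_nodes=50` and `embedding_dim=13` are made up; the target network uses the CIFAR-10 defaults above): what travels between server and client each round is a full `CNNTarget` state dict in either direction, exactly the payload a plain FedAvg round would ship, while the much larger hypernetwork never leaves the server.

```python
net = CNNTarget(in_channels=3, n_kernels=16, out_dim=10)
hnet = CNNHyper(n_nodes=50, embedding_dim=13)

# per-round payload: one full set of target-network weights in each direction
target_params = sum(p.numel() for p in net.parameters())
# the hypernetwork itself stays on the server
hyper_params = sum(p.numel() for p in hnet.parameters())

print(f"transmitted per round (either direction): {target_params:,} parameters")
print(f"kept on the server:                       {hyper_params:,} parameters")
```

Running this shows the hypernetwork is far larger than the target network, which is why keeping it server-side matters; but the per-round payload still scales with the client model, which is exactly the concern in the second bullet.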