
Paper Notes (9) - "Personalized Federated Learning with Gaussian Processes"

Don't expect much from this blog post, because I couldn't really understand either the paper or the code, especially the code. Treat it as a translation/reading log: what did the authors do, and did they achieve the goals they set?

Personalized Federated Learning with Gaussian Processes

This post won't touch any implementation details (because I didn't understand them), and it won't discuss the method's advantages either (because I also couldn't work out what exactly makes it novel). It only describes what the paper does, so it will be a rather hazy post (let's just pretend I understood it).

Motivation

The motivation the paper states for itself is to “learn effectively across clients even though each client has unique data that is often limited in size”, i.e., how to build a personalized federated learning (PFL) model when each client has only a small amount of data. The authors' idea is that Gaussian processes (GPs) perform well in the few-sample regime, so they bring GPs into FL.

Challenges & solutions

Non-Gaussian likelihood in classification problems

Since many FL tasks are classification problems, the marginal distribution obtained in such problems is not Gaussian. The authors address this by introducing an auxiliary variable \(\omega\) that follows the Pólya-Gamma augmentation distribution.

Here \(g_k\sim Gamma(b,1)\) are the Gamma variables used to construct \(\omega\), and \(\omega\) satisfies a key property.
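Filling the gap with my own understanding of the standard Pólya-Gamma construction (Polson et al., 2013): a variable \(\omega\sim PG(b,c)\) is defined through the infinite sum

\[ \omega \overset{d}{=} \frac{1}{2\pi^{2}}\sum_{k=1}^{\infty}\frac{g_{k}}{(k-\tfrac{1}{2})^{2}+c^{2}/(4\pi^{2})},\qquad g_{k}\sim Gamma(b,1), \]

and the key identity that makes the logistic likelihood tractable is

\[ \frac{(e^{\psi})^{a}}{(1+e^{\psi})^{b}} = 2^{-b}\,e^{\kappa\psi}\int_{0}^{\infty} e^{-\omega\psi^{2}/2}\,p(\omega\mid b,0)\,d\omega,\qquad \kappa=a-\frac{b}{2}. \]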

With this augmentation, the likelihood can be written in the following form:
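To my understanding, the missing formula here is the standard PG-augmented logistic likelihood: for a binary label \(y\in\{0,1\}\) with latent GP value \(f\), applying the identity above with \(a=y,\ b=1\) gives

\[ p(y\mid f)=\frac{(e^{f})^{y}}{1+e^{f}} \quad\Longrightarrow\quad p(y\mid f,\omega)\ \propto\ \exp\!\Big(\kappa f-\frac{\omega f^{2}}{2}\Big),\qquad \kappa=y-\tfrac{1}{2}, \]

which is Gaussian in \(f\) once \(\omega\) is fixed.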

and the posterior is:
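Again filling in from the standard Pólya-Gamma results, the conditional posterior of the augmentation variable is just an exponentially tilted PG distribution:

\[ p(\omega\mid \psi)\ \propto\ e^{-\omega\psi^{2}/2}\,p(\omega\mid b,0) \quad\Longrightarrow\quad \omega\mid \psi\ \sim\ PG(b,\psi). \]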

Multiclass classification

The Pólya-Gamma augmentation above only works for binary classification, so it does not directly apply to multi-class problems such as Cifar10 or Cifar100. A multi-class problem can, however, be decomposed into several binary ones, so the authors propose building a GP-tree: for Cifar10 the GP-tree has 10 leaf nodes, one per image class, and every non-leaf node of the tree carries a Pólya-Gamma-augmented GP (a binary tree with 10 leaves has 9 internal nodes, i.e., 9 such GPs).

In the paper the tree is built via K-means or hierarchical clustering; see the code below for details:

# imports needed by the code snippets in this post
from collections import deque

import numpy as np
import torch
from sklearn.cluster import AgglomerativeClustering


class Split(object):
    # split the node into two or more branches
    # the base class: with three classes or fewer there is nothing to cluster,
    # so each class is simply assigned its own branch index
    def __init__(self, labels, branches=3):
        self.old_to_new = {}  # maps original class label -> branch index
        self.labels = labels
        self.classes = np.unique(labels)
        self.num_classes = self.classes.shape[0]
        self.branches = branches

    def split(self, *args, **kwargs):
        if self.num_classes == 3:
            self.old_to_new[self.classes[0]] = 0
            self.old_to_new[self.classes[1]] = 1
            self.old_to_new[self.classes[2]] = 2
        elif self.num_classes == 2:
            self.old_to_new[self.classes[0]] = 0
            self.old_to_new[self.classes[1]] = 1
        else:
            self.old_to_new[self.classes[0]] = 0
        return self.old_to_new

class ProtoTypeSplit(Split):
    """
    split labels associated with a node to x branches by the prototype of each class.
    close classes should be grouped together
    :param labels: numpy array of the labels
    :param branches: the number of branches
    :param prototype: dictionary of {label: np.array()}
    :param affinity: Metric - “euclidean”, “l1”, “l2”, “manhattan”, “cosine”
    :param linkage: Distance to use between sets of observation: “ward”, “complete”, “average”, “single”
    :return the original classes partitioned to nodes
    """
    def __init__(self, labels, branches, prototype, affinity='cosine', linkage='complete'):
        super().__init__(labels, branches)
        self.affinity = affinity
        self.linkage = linkage
        self.prototype = prototype

    def split(self):

        # hierarchical clustering
        n_clusters = self.branches
        clustering = AgglomerativeClustering(n_clusters=n_clusters, affinity=self.affinity, linkage=self.linkage)
        lbl_assignment = clustering.fit(list(self.prototype.values())).labels_

        for o, n in zip(self.prototype.keys(), lbl_assignment):
            self.old_to_new.update({o: n.item()})

        return self.old_to_new

class MeanSplitAgglomerative(Split):
    """
    split labels associated with a node to x branches by the mean vector of each class.
    close classes should be grouped together
    :param labels: numpy array of the labels
    :param branches: the number of branches
    :param data: numpy array of the data
    :param affinity: Metric - “euclidean”, “l1”, “l2”, “manhattan”, “cosine”
    :param linkage: Distance to use between sets of observation: “ward”, “complete”, “average”, “single”
    :return the original classes partitioned to nodes
    """
    def __init__(self, labels, branches, data, affinity='euclidean', linkage='ward'):
        super().__init__(labels, branches)
        self.affinity = affinity
        self.linkage = linkage
        self.data = data

    def split(self):

        # mean vector of each class
        means = np.array([0])  # placeholder; overwritten on the first iteration (idx == 0)
        for idx, i in enumerate(self.classes):
            tmp = self.data[np.where(self.labels == i)]
            mean_vec = np.mean(tmp, axis=0, keepdims=True)
            means = mean_vec if idx == 0 else np.concatenate((means, mean_vec), axis=0)

        # hierarchical clustering
        n_clusters = self.branches
        clustering = AgglomerativeClustering(n_clusters=n_clusters, affinity=self.affinity, linkage=self.linkage)
        lbl_assignment = clustering.fit(means).labels_

        for o, n in zip(self.classes, lbl_assignment):
            self.old_to_new.update({o.item(): n.item()})

        return self.old_to_new
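

# Note: BinaryTree, NodepFedGPIPData, detach_to_numpy and pytorch_take used below
# are helper classes/functions defined elsewhere in the authors' code.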
class BinaryTreepFedGPIPData(BinaryTree):

    def __init__(self, args, device):
        super(BinaryTreepFedGPIPData, self).__init__(args, device)
        self.root = NodepFedGPIPData()
        self.root.id = 0
        self.root.depth = 0

    def build_tree(self, root, X, Y, X_bar):
        """
        Build binary tree with GP attached to each node
        """
        # root
        q = deque()

        # push source vertex into the queue
        q.append((root, X, Y))
        curr_id = 1
        gp_counter = 0  # for getting avg. loss over the whole tree

        # loop till queue is empty
        while q:
            # pop front node from queue
            root, root_X, root_Y = q.popleft()
            node_classes, _ = torch.sort(torch.unique(root_Y))
            num_classes = node_classes.size(0)

            # Xbar's of current node
            X_bar_root = X_bar[node_classes, ...]

            # two classes or less - no heuristic for splitting
            split_method = 'MeanSplitKmeans' if num_classes > 2 else 'Split'
            root_old_to_new = \
                self.split_func(detach_to_numpy(root_X),
                                detach_to_numpy(root_Y))[split_method].split()

            root.set_data(root_Y, root_old_to_new)

            # build label vector of current node
            num_Xbars = X_bar_root.shape[1]
            i = 0
            for original_lbl, node_lbl in root_old_to_new.items():
                Y_bar_class = torch.zeros(num_Xbars, device=Y.device, dtype=Y.dtype) if node_lbl == 0 \
                    else torch.ones(num_Xbars, device=Y.device, dtype=Y.dtype)
                Y_bar_root = Y_bar_class if i == 0 else torch.cat((Y_bar_root, Y_bar_class))
                i += 1

            # leaf node
            if num_classes == 1:
                # logging.info('Reached a leaf node. Node index: ' + str(root.id) + ' ')
                continue

            # Internal node
            else:
                gp_counter += 1
                root.set_model(self.args.kernel_function,
                               self.args.num_gibbs_steps_train, self.args.num_gibbs_draws_train,
                               self.args.num_gibbs_steps_test, self.args.num_gibbs_draws_test,
                               self.args.outputscale_increase, self.args.outputscale,
                               self.args.lengthscale, Y_bar_root, self.args.balance_classes)

                left_X, left_Y = pytorch_take(root_X, root_Y, root.new_to_old[0])
                right_X, right_Y = pytorch_take(root_X, root_Y, root.new_to_old[1])
                child_X = [left_X, right_X]
                child_Y = [left_Y, right_Y]

                branches = 2
                for i in range(branches):
                    child = NodepFedGPIPData()
                    child.id = curr_id
                    curr_id += 1
                    child.depth = root.depth + 1
                    root.set_child(child, i)
                    q.append((child, child_X[i], child_Y[i]))

        return gp_counter
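To make the splitting logic a bit more concrete, here is a toy usage sketch I wrote myself (not from the repo): four classes whose per-class mean vectors form two obvious groups get mapped to two branches. (On newer scikit-learn versions you may need to rename AgglomerativeClustering's affinity argument to metric.)

# toy example (mine, not from the repo): 4 classes, 2 branches
rng = np.random.default_rng(0)
data = np.vstack([
    rng.normal(0.0, 1.0, (20, 5)),   # class 0
    rng.normal(0.2, 1.0, (20, 5)),   # class 1
    rng.normal(8.0, 1.0, (20, 5)),   # class 2
    rng.normal(8.2, 1.0, (20, 5)),   # class 3
])
labels = np.repeat([0, 1, 2, 3], 20)

splitter = MeanSplitAgglomerative(labels, branches=2, data=data)
print(splitter.split())  # e.g. {0: 0, 1: 0, 2: 1, 3: 1}, close classes share a branch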

Then, for a sample of class \(t\), the likelihood function is
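My best guess at the missing formula: the class-\(t\) likelihood is the product of the binary decisions made at the internal nodes along the root-to-leaf path of class \(t\),

\[ p(y=t\mid \mathbf{f})=\prod_{v\in P^{t}} p(y_{v}\mid f_{v}). \]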

where \(P^{t}\) is the path the sample traverses (marked in the code via old_to_new) and \(v\) denotes a node on that path. The resulting posterior distributions are
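As far as I can tell, these are the per-node Gibbs conditionals of the PG-augmented GP at each internal node \(v\):

\[ \omega_{v}\mid f_{v}\ \sim\ PG(1,f_{v}),\qquad \mathbf{f}_{v}\mid \boldsymbol{\omega}_{v},\mathbf{y}_{v}\ \sim\ \mathcal{N}\big(\Sigma_{v}\boldsymbol{\kappa}_{v},\ \Sigma_{v}\big),\quad \Sigma_{v}=\big(K_{v}^{-1}+\mathrm{diag}(\boldsymbol{\omega}_{v})\big)^{-1},\ \boldsymbol{\kappa}_{v}=\mathbf{y}_{v}-\tfrac{1}{2}. \]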

Kernel function

For data such as images or audio, the authors use a deep network to embed each sample into a vector, which is then fed into the kernel function (the RBF kernel or linear kernel in the paper). Client \(c\) optimizes the network parameters as

\[\begin{aligned} \nabla\mathcal{L}_c^{ML}(\theta;D_c)&=\sum_v\nabla\log p_\theta(\mathbf{y}_v\vert \mathbf{X}_v)\\ &= \sum_v \frac{\nabla p_\theta(\mathbf{y}_v\vert \mathbf{X}_v)}{p_\theta(\mathbf{y}_v\vert \mathbf{X}_v)}\\ &= \sum_v \frac{\nabla \int_{\omega} p_\theta(\mathbf{y}_v, \omega\vert \mathbf{X}_v)\,d\omega}{p_\theta(\mathbf{y}_v\vert \mathbf{X}_v)}\\ &= \sum_v \int \frac{\nabla p_\theta(\mathbf{y}_v, \omega\vert \mathbf{X}_v)}{p_\theta(\mathbf{y}_v\vert \mathbf{X}_v)}\,d\omega\\ &= \sum_v \int \frac{p_{\theta}(\mathbf{y}_v,\omega\vert \mathbf{X}_v)}{p_\theta(\mathbf{y}_v\vert \mathbf{X}_v)}\nabla\log p_\theta(\mathbf{y}_v,\omega\vert \mathbf{X}_v)\, d\omega\\ &= \sum_v \int p_{\theta}(\omega\vert \mathbf{y}_v, \mathbf{X}_v)\nabla(\log p_\theta(\mathbf{y}_v\vert\omega,\mathbf{X}_v)+\log p(\omega\vert \mathbf{X}_v))\, d\omega\\ &= \sum_v \int p_{\theta}(\omega\vert \mathbf{y}_v, \mathbf{X}_v)\nabla\log p_\theta(\mathbf{y}_v\vert\omega,\mathbf{X}_v)\, d\omega\\ \end{aligned} \]
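Just to make the "deep kernel" idea concrete, here is a minimal sketch of my own (not the authors' code; the layer sizes and class name are made up): a small network produces embeddings, and an RBF kernel is evaluated on those embeddings, so the gradient above is what trains the network through the GP marginal likelihood.

import torch
import torch.nn as nn

class DeepRBFKernel(nn.Module):
    """RBF kernel evaluated on learned embeddings (illustrative only)."""
    def __init__(self, in_dim=32, embed_dim=8):
        super().__init__()
        # shared feature extractor (the network learned jointly across clients)
        self.net = nn.Sequential(nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, embed_dim))
        self.log_lengthscale = nn.Parameter(torch.zeros(()))

    def forward(self, x1, x2):
        z1, z2 = self.net(x1), self.net(x2)   # embed both sets of inputs
        d2 = torch.cdist(z1, z2).pow(2)       # pairwise squared distances in embedding space
        return torch.exp(-0.5 * d2 / self.log_lengthscale.exp() ** 2)

# K = DeepRBFKernel()(X, X) is the Gram matrix the GP sees; gradients w.r.t. the
# network parameters flow through K via the marginal-likelihood objective above.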

Limited data size

The paper handles clients with little data by broadcasting a common dataset that helps them build their models (I couldn't figure out exactly how this works).

Computational constraint

GPs require a matrix inversion, which normally costs \(\mathcal{O}(N^3)\) for \(N\) samples. Using the common dataset above, the authors reduce this to \(\mathcal{O}(M^3)\), where \(M\) is the size of the common dataset. (I'm not sure exactly how the reduction works; my impression is that the inversion is simply done on the common dataset instead, using it as the training set.)
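In other words (my reading, not a claim taken from the paper), the \(N\times N\) kernel matrix that would otherwise need inverting is replaced by the \(M\times M\) one built on the shared points:

\[ K_{NN}^{-1}\ \big[\mathcal{O}(N^{3})\big]\ \longrightarrow\ K_{MM}^{-1}\ \big[\mathcal{O}(M^{3})\big],\qquad M\ll N. \]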

Summary

Let me shamelessly write a summary anyway:

  • The authors want to build personalized models even for clients with too little data, and GPs happen to do well with few samples. In the authors' words, what the whole system learns is the deep network in front of the kernel function, and that network is shared by all clients.
  • Both the limited data size and the computational constraint are handled with a common dataset (called inducing points in the paper) that is then treated as a training set. Honestly, this doesn't feel like a methodological innovation to me. The logic of the paper reads like: GPs work well with few samples \(\rightarrow\) use them for PFL; clients have little data \(\rightarrow\) broadcast them a shared dataset as a training set, which also fixes the high cost of the matrix inversion (for clients with lots of data). But then, why not just broadcast a shared dataset and skip the GP altogether?
  • All in all, the authors do propose a PFL method. (I didn't understand the code, the integrals make my head spin, and I'm not going to use it.)