1. 程式人生 > >keras中models的Squential類的原始碼簡介

keras中models的Squential類的原始碼簡介

keras中最重要的就是models的Sequential類了,下面我結合原始碼以及自己的理解,儘可能的去學習並能夠說明白,原始碼太多,先貼一個fit函式的實現:

    def fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[],
            validation_split=0., validation_data=None, shuffle=True,
            class_weight=None, sample_weight=None, **kwargs):
        '''

        Args:
            x: 表示輸入可以是narray, 如果是多個輸入,也可以是[narray, narray], 必須有
            y: labels,only a narray, 必須有
            batch_size: mini batch表示多少次更新一次權重,預設是32
            nb_epoch: 需要迭代多少次去訓練這個模型,預設是10
            verbose: 是不是輸出列印log到標準輸出,預設是列印
            callbacks: 回撥函式(暫時不是很理解這個地方怎麼用)
            validation_split: 測試資料的比例,預設是0
            validation_data: 測試資料,tuple(input , lable)預設是空
            shuffle:不懂
            class_weight:不懂
            sample_weight:不懂
            **kwargs: 只有一個候選項就是 'show_accuracy'

        Returns:

        '''
        '''Trains the model for a fixed number of epochs.

        # Arguments
            x: input data, as a Numpy array or list of Numpy arrays
                (if the model has multiple inputs).
            y: labels, as a Numpy array.
            batch_size: integer. Number of samples per gradient update.
            nb_epoch: integer, the number of epochs to train the model.
            verbose: 0 for no logging to stdout,
                1 for progress bar logging, 2 for one log line per epoch.
            callbacks: list of `keras.callbacks.Callback` instances.
                List of callbacks to apply during training.
                See [callbacks](/callbacks).
            validation_split: float (0. < x < 1).
                Fraction of the data to use as held-out validation data.
            validation_data: tuple (X, y) to be used as held-out
                validation data. Will override validation_split.
            shuffle: boolean or str (for 'batch').
                Whether to shuffle the samples at each epoch.
                'batch' is a special option for dealing with the
                limitations of HDF5 data; it shuffles in batch-sized chunks.
            class_weight: dictionary mapping classes to a weight value,
                used for scaling the loss function (during training only).
            sample_weight: Numpy array of weights for
                the training samples, used for scaling the loss function
                (during training only). You can either pass a flat (1D)
                Numpy array with the same length as the input samples
                (1:1 mapping between weights and samples),
                or in the case of temporal data,
                you can pass a 2D array with shape (samples, sequence_length),
                to apply a different weight to every timestep of every sample.
                In this case you should make sure to specify
                sample_weight_mode="temporal" in compile().

        # Returns
            A `History` object. Its `History.history` attribute is
            a record of training loss values and metrics values
            at successive epochs, as well as validation loss values
            and validation metrics values (if applicable).
        '''
        if self.model is None:
            raise Exception('The model needs to be compiled before being used.')
        if 'show_accuracy' in kwargs:
            kwargs.pop('show_accuracy')
            warnings.warn('The "show_accuracy" argument is deprecated, '
                          'instead you should pass the "accuracy" metric to '
                          'the model at compile time:\n'
                          '`model.compile(optimizer, loss, '
                          'metrics=["accuracy"])`')
        if kwargs:
            raise Exception('Received unknown keyword arguments: ' +
                            str(kwargs))
        return self.model.fit(x, y,
                              batch_size=batch_size,
                              nb_epoch=nb_epoch,
                              verbose=verbose,
                              callbacks=callbacks,
                              validation_split=validation_split,
                              validation_data=validation_data,
                              shuffle=shuffle,
                              class_weight=class_weight,
                              sample_weight=sample_weight)
主要是學會怎麼使用,因為這段程式碼放到整個類中去看才有意義,所以,後續繼續補充吧, 發現欠了好多債了,後續需要補充的東西太多了,逼我把原始碼看完的節奏。