TensorFlow之下載和匯入mnists資料集的read_data_sets()錯誤分析(從原始碼的角度)
阿新 • • 發佈:2019-01-23
在用TensorFlow的mnist資料集做手寫數字識別任務時,使用tensorflow自帶的模組(如下所示)下載和匯入資料集會報錯,原因是該模組爬取的資料集網站不能訪問。。因為該模組是用python內建urllib
模組來下載資料的,需要提供有效的資料集網站地址。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(data_dir, one_hot=True)
首先我看看read_data_sets()
函式原始碼:
def read_data_sets(train_dir,
fake_data=False,
one_hot=False,
dtype=dtypes.float32,
reshape=True,
validation_size=5000,
seed=None,
source_url=DEFAULT_SOURCE_URL):
if fake_data:
def fake():
return DataSet(
[], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)
train = fake()
validation = fake()
test = fake()
return base.Datasets(train=train, validation=validation, test=test)
if not source_url: # empty string check
source_url = DEFAULT_SOURCE_URL
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
source_url + TRAIN_IMAGES)
with gfile.Open(local_file, 'rb') as f:
train_images = extract_images(f)
local_file = base.maybe_download(TRAIN_LABELS, train_dir,
source_url + TRAIN_LABELS)
with gfile.Open(local_file, 'rb') as f:
train_labels = extract_labels(f, one_hot=one_hot)
local_file = base.maybe_download(TEST_IMAGES, train_dir,
source_url + TEST_IMAGES)
with gfile.Open(local_file, 'rb') as f:
test_images = extract_images(f)
local_file = base.maybe_download(TEST_LABELS, train_dir,
source_url + TEST_LABELS)
with gfile.Open(local_file, 'rb') as f:
test_labels = extract_labels(f, one_hot=one_hot)
if not 0 <= validation_size <= len(train_images):
raise ValueError('Validation size should be between 0 and {}. Received: {}.'
.format(len(train_images), validation_size))
validation_images = train_images[:validation_size]
validation_labels = train_labels[:validation_size]
train_images = train_images[validation_size:]
train_labels = train_labels[validation_size:]
options = dict(dtype=dtype, reshape=reshape, seed=seed)
train = DataSet(train_images, train_labels, **options)
validation = DataSet(validation_images, validation_labels, **options)
test = DataSet(test_images, test_labels, **options)
return base.Datasets(train=train, validation=validation, test=test)
在呼叫這個方法的時候是設定source_url
引數的,如果沒有設定,就會使用預設的source_url = DEFAULT_SOURCE_URL
,預設的DEFAULT_SOURCE_URL
在原始碼中也可以找到:
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
有時這個url地址是能用的,報錯的話可以將註釋掉的url地址替換掉下邊的地址,就可以用了。也可以在呼叫read_data_sets
函式的時候指定source_url
引數為下載的資料集網址。
原始碼中下載資料集呼叫的是base.maybe_download
函式,它的原始碼為:
def maybe_download(filename, work_directory, source_url):
"""Download the data from source url, unless it's already here.
Args:
filename: string, name of the file in the directory.
work_directory: string, path to working directory.
source_url: url to download from if file doesn't exist.
Returns:
Path to resulting file.
"""
if not gfile.Exists(work_directory):
gfile.MakeDirs(work_directory)
filepath = os.path.join(work_directory, filename)
if not gfile.Exists(filepath):
temp_file_name, _ = urlretrieve_with_retry(source_url)
gfile.Copy(temp_file_name, filepath)
with gfile.GFile(filepath) as f:
size = f.size()
print('Successfully downloaded', filename, size, 'bytes.')
return filepath
下載資料集又呼叫了urlretrieve_with_retry
函式,原始碼為:
def urlretrieve_with_retry(url, filename=None):
return urllib.request.urlretrieve(url, filename)
通過上述對原始碼的一步步追蹤,最後看到,tensorflow下載資料集使用的python內建的urllib
模組