微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

python MNIST手写识别数据调用API的方法

这篇文章主要介绍了python MNIST手写识别数据调用API的方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧

MNIST数据集比较小,一般入门机器学习都会采用这个数据集来训练

下载地址:yann.lecun.com/exdb/mnist/

有4个有用的文件

train-images-idx3-ubyte: training set images

train-labels-idx1-ubyte: training set labels

t10k-images-idx3-ubyte: test set images

t10k-labels-idx1-ubyte: test set labels

The training set contains 60000 examples, and the test set 10000 examples. 数据集存储是用binary file存储的,黑白图片

下面给出load数据集的代码

import os import struct import numpy as np import matplotlib.pyplot as plt def load_mnist(): ''' Load mnist data http://yann.lecun.com/exdb/mnist/ 60000 training examples 10000 test sets Arguments: kind: 'train' or 'test', string charater input with a default value 'train' Return: xxx_images: n*m array, n is the sample count, m is the feature number which is 28*28 xxx_labels: class labels for each image, (0-9) ''' root_path = '/home/cc/deep_learning/data_sets/mnist' train_labels_path = os.path.join(root_path, 'train-labels.idx1-ubyte') train_images_path = os.path.join(root_path, 'train-images.idx3-ubyte') test_labels_path = os.path.join(root_path, 't10k-labels.idx1-ubyte') test_images_path = os.path.join(root_path, 't10k-images.idx3-ubyte') with open(train_labels_path, 'rb') as lpath: # '>' denotes bigedian # 'I' denotes unsigned char magic, n = struct.unpack('>II', lpath.read(8)) #loaded = np.fromfile(lpath, dtype = np.uint8) train_labels = np.fromfile(lpath, dtype = np.uint8).astype(np.float) with open(train_images_path, 'rb') as ipath: magic, num, rows, cols = struct.unpack('>IIII', ipath.read(16)) loaded = np.fromfile(train_images_path, dtype = np.uint8) # images start from the 16th bytes train_images = loaded[16:].reshape(len(train_labels), 784).astype(np.float) with open(test_labels_path, 'rb') as lpath: # '>' denotes bigedian # 'I' denotes unsigned char magic, n = struct.unpack('>II', lpath.read(8)) #loaded = np.fromfile(lpath, dtype = np.uint8) test_labels = np.fromfile(lpath, dtype = np.uint8).astype(np.float) with open(test_images_path, 'rb') as ipath: magic, num, rows, cols = struct.unpack('>IIII', ipath.read(16)) loaded = np.fromfile(test_images_path, dtype = np.uint8) # images start from the 16th bytes test_images = loaded[16:].reshape(len(test_labels), 784) return train_images, train_labels, test_images, test_labels

再看看图片集是什么样的:

def test_mnist_data(): ''' Just to check the data Argument: none Return: none ''' train_images, train_labels, test_images, test_labels = load_mnist() fig, ax = plt.subplots(nrows = 2, ncols = 5, sharex = True, sharey = True) ax =ax.flatten() for i in range(10): img = train_images[i][:].reshape(28, 28) ax[i].imshow(img, cmap = 'Greys', interpolation = 'nearest') print('corresponding labels = %d' %train_labels[i]) if __name__ == '__main__': test_mnist_data()

跑出的结果如下:

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持编程之家。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐