Pytorch学习6:数据的读入和处理

Pytorch

Posted by Shaozi on May 6, 2018

这一部分将是我这个教程的最后一节,其他的内容还是建议大家去pytorch的官网上进行查阅,因为实在是太多。完整的看一遍意义不大,这篇教程之后将以各种实例的形式介绍pytorch的使用。不再阅读官方的教程。

完成这一部分的demo需要在conda环境下添加两个新的工具包:

conda install scikit-image  #用于图像的读取与转换
conda install pandas #用于csv文件的读取

这个demo将实现一个面部表情的获取,数据下载链接 将解压后的数据库和代码文件放在一起。 图像的标注格式如下所示:

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

第一列是图像文件的名字,从第二列开始,每两列表示一个点的坐标,数据库使用68个点来表示人脸的表情

读取CSV文件的代码如下:

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode
landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

pytorch 提供了一个抽象类,用于处理任意的数据库:torch.utils.data.Dataset 提供了两个方法,第一个__len__用于的到数据库的长度(len(dataset));第二个是__getitem__用于得到某一个位置的数据(dataset[i]) ··· class FaceLandmarksDataset(Dataset): “"”Face Landmarks dataset.”””

def __init__(self, csv_file, root_dir, transform=None):
    """
    Args:
        csv_file (string): Path to the csv file with annotations.
        root_dir (string): Directory with all the images.
        transform (callable, optional): Optional transform to be applied
            on a sample.
    """
    self.landmarks_frame = pd.read_csv(csv_file)
    self.root_dir = root_dir
    self.transform = transform

def __len__(self):
    return len(self.landmarks_frame)

def __getitem__(self, idx):
    img_name = os.path.join(self.root_dir,
                            self.landmarks_frame.iloc[idx, 0])
    image = io.imread(img_name)
    landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
    landmarks = landmarks.astype('float').reshape(-1, 2)
    sample = {'image': image, 'landmarks': landmarks}

    if self.transform:
        sample = self.transform(sample)

    return sample ··· 下面调用该类读取数据库,并且输出前四个样本: ``` face_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                root_dir='faces/')

fig = plt.figure()

for i in range(len(face_dataset)): sample = face_dataset[i]

print(i, sample['image'].shape, sample['landmarks'].shape)

ax = plt.subplot(1, 4, i + 1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
show_landmarks(**sample)

if i == 3:
    plt.show()
    break ``` [官网教程](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html)上还有一些图像转换增强的类,可以参考.