PyTorch - Caricamento dati

PyTorch include un pacchetto chiamato torchvision che viene utilizzato per caricare e preparare il set di dati. Include due funzioni di base, ovvero Dataset e DataLoader, che aiutano nella trasformazione e nel caricamento del set di dati.

Set di dati

Il set di dati viene utilizzato per leggere e trasformare un punto dati dal set di dati specificato. La sintassi di base da implementare è menzionata di seguito:

trainset = torchvision.datasets.CIFAR10(root = './data', train = True,
   download = True, transform = transform)

DataLoader viene utilizzato per mescolare e raggruppare i dati. Può essere utilizzato per caricare i dati in parallelo con i worker multiprocessing.

trainloader = torch.utils.data.DataLoader(trainset, batch_size = 4,
   shuffle = True, num_workers = 2)

Esempio: caricamento del file CSV

Usiamo il pacchetto Python Panda per caricare il file csv. Il file originale ha il seguente formato: (nome dell'immagine, 68 punti di riferimento - ogni punto di riferimento ha coordinate ax, y).

landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)