Adam Coates and colleagues have a string of very interesting papers in which they propose using k-means to train convolutional networks. Training the first layer boils down to about 10 lines of Python code (thanks to sklearn's implementation of k-means).
Here are the first-layer filters I got by training on the CIFAR-10 dataset:
And here are the first-layer filters from cuda-convnet trained on the same data (but in a supervised manner):
Of course, what matters in the end is the overall accuracy of the network, but at least qualitatively the filters look similar, and it's pretty fun that you can get this sort of output from k-means.
Here's the code if you want to try it out yourself (and yes, that's more than 10 lines, but most of it is just setup, loading data, and displaying results):
import pickle
from os.path import join

import numpy as np
from matplotlib import pyplot as plt
from sklearn.cluster import KMeans

# download data here: http://www.cs.toronto.edu/~kriz/cifar.html
# (the batches were pickled under Python 2, hence encoding='latin1')
with open(join('data', 'cifar-10-batches-py', 'data_batch_1'), 'rb') as f:
    data = pickle.load(f, encoding='latin1')

# rows are flat 3x32x32 images; scale to [0, 1] and move channels last
images = data['data'].reshape((-1, 3, 32, 32)).astype('float32') / 255
images = np.rollaxis(images, 1, 4)

# collect non-overlapping 5x5 patches
patches = np.zeros((0, 5, 5, 3))
for x in range(0, 32 - 5, 5):
    for y in range(0, 32 - 5, 5):
        patches = np.concatenate(
            (patches, images[:, x:x + 5, y:y + 5, :]), axis=0)
# flatten each patch to a 75-dimensional row vector
patches = patches.reshape((patches.shape[0], -1))

# normalize each dimension (the ptp term keeps sigma away from zero)
mu = patches.mean(axis=0)
sigma = patches.std(axis=0) + np.ptp(patches, axis=0) / 20.0
patches = (patches - mu[np.newaxis, :]) / (sigma[np.newaxis, :])

# zca whiten (eigh, since the covariance matrix is symmetric)
eig_values, eig_vec = np.linalg.eigh(np.cov(patches.T))
zca = eig_vec.dot(np.diag((eig_values + 0.01) ** -0.5).dot(eig_vec.T))
patches = np.dot(patches, zca)

# k-means: each cluster center becomes one 5x5x3 filter
NUM_FILTERS = 64
km = KMeans(
    n_clusters=NUM_FILTERS,
    random_state=0,
    n_init=1,
    verbose=True)
km.fit(patches)
filters = km.cluster_centers_.reshape((NUM_FILTERS, 5, 5, 3))

# display the filters in a grid with 4 rows
fig = plt.figure()
num_col = int(np.ceil(NUM_FILTERS / 4))
for i in range(NUM_FILTERS):
    ax = fig.add_subplot(4, num_col, i + 1)  # subplot indices start at 1
    filter_ = filters[i, ...]
    # rescale each filter to [0, 1] so imshow can render it as RGB
    filter_ -= filter_.min()
    filter_ /= filter_.max()
    ax.imshow(filter_, interpolation='none')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
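The ZCA whitening step is the least obvious part of the script, so here is a small sanity check of it in isolation. This is only a sketch: the helper name `zca_whiten` and the synthetic data are my own stand-ins (not part of the script above), and the data is deliberately correlated so that whitening has something to undo. After whitening, the covariance should be close to the identity, and because ZCA rotates back into the original axes, the whitening matrix itself is symmetric.

```python
import numpy as np


def zca_whiten(X, eps=0.01):
    """ZCA-whiten the rows of X; returns the whitened data and the matrix.

    Hypothetical helper for illustration. eps plays the same role as the
    0.01 added to the eigenvalues in the script above.
    """
    X = X - X.mean(axis=0)
    eig_values, eig_vec = np.linalg.eigh(np.cov(X.T))
    zca = eig_vec @ np.diag((eig_values + eps) ** -0.5) @ eig_vec.T
    return X @ zca, zca


rng = np.random.default_rng(0)
dim = 8

# synthetic, strongly correlated data standing in for flattened patches
mixing = np.eye(dim) + 0.5 * np.ones((dim, dim))
X = rng.normal(size=(5000, dim)) @ mixing

X_white, zca = zca_whiten(X)

# largest deviation of the whitened covariance from the identity matrix
max_dev = np.abs(np.cov(X_white.T) - np.eye(dim)).max()
print("max deviation from identity:", max_dev)
```

The deviation is not exactly zero because of the `eps` term, which damps directions with small eigenvalues instead of amplifying noise in them.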