'''
ab_ucf50_groups.py

Jason Corso

Train and test an SVM on the UCF 50 data.
Does group-wise 5-fold cross-validation to avoid mixing groups between
train-test sets.

The processed UCF 50 data is available at
http://www.cse.buffalo.edu/~jcorso/r/actionbank

MAKE sure that ../code is in your PYTHONPATH, i.e.,
    export PYTHONPATH=../code
before running this script
'''

import argparse
import glob
import gzip
import numpy as np
import os
import os.path
import random as rnd
import scipy.io as sio
import multiprocessing as multi

from actionbank import *
import ab_svm

# Layout of the cross-validation: groups g01..g25 are split into 5
# consecutive folds of 5 groups each.
_NUM_GROUPS = 25
_NUM_FOLDS = 5
_NUM_CLASSES = 50


def _load_vector(path):
    '''Return the array stored in the gzip-compressed .npy file at path.'''
    # The context manager closes the handle even when np.load raises.
    with gzip.open(path) as fp:
        return np.load(fp)


def _load_split(groups, indices, vlen, kind):
    '''Load every file of the selected groups into a data matrix and labels.

    groups  -- list of groups; each group is a list of entries where
               entry[0] is a file path and entry[1] is the class index
    indices -- indices into groups selecting this split (iterated twice)
    vlen    -- length of each stored vector
    kind    -- word used in the progress message ("training" / "testing")

    Returns (D, Y): D is a (count, vlen) uint8 matrix with one row per
    file (loaded values are cast to uint8 on assignment) and Y is a float
    array holding the class index of each row.
    '''
    count = 0
    for i in indices:
        count += len(groups[i])
    print("have %d %s files" % (count, kind))

    D = np.zeros((count, vlen), np.uint8)
    # -1000 is a placeholder; every slot is overwritten in the loop below.
    Y = np.ones((count)) * -1000

    row = 0
    for i in indices:
        for entry in groups[i]:
            D[row][:] = _load_vector(entry[0])
            Y[row] = entry[1]
            row += 1
    return D, Y


def testgroups(groups, training, testing):
    '''Main group-wise testing routine.

    Trains a linear SVM on the files of the training groups, classifies
    the files of the testing groups and reports the accuracy.

    groups   -- list of groups; each group is a list of [path, class] entries
    training -- indices into groups used for training
    testing  -- indices into groups used for testing

    Returns the accuracy on the testing files as a percentage.
    '''
    # Single-argument print calls behave the same under Python 2 and 3.
    print("Training %s" % (training,))
    print("Testing %s" % (testing,))

    # The vector length is read from the first file of the first group.
    vlen = len(_load_vector(groups[0][0][0]))
    print("vector length is %d" % vlen)

    Dtrain, Ytrain = _load_split(groups, training, vlen, "training")
    Dtest, Ytest = _load_split(groups, testing, vlen, "testing")

    # Leave one core free, but never ask for fewer than one thread
    # (cpu_count() - 1 is 0 on a single-core machine).
    threads = max(1, multi.cpu_count() - 1)
    res = ab_svm.SVMLinear(Dtrain, np.int32(Ytrain), Dtest,
                           threads=threads, useLibLinear=True, useL1R=False)

    tp = np.sum(res == Ytest)
    accuracy = (np.float64(tp) / Dtest.shape[0]) * 100
    print('Accuracy is %.1f%%' % accuracy)
    return accuracy


def _collect_groups(root, classes):
    '''Build the per-group file lists for groups 1.._NUM_GROUPS.

    root    -- directory laid out as root/class/name_gXX_cYY<banked_suffix>
    classes -- class directory names; a file's label is the position of
               its class in this list

    Returns a list with one entry per group, each a list of [path, class]
    pairs. banked_suffix is not defined in this file; it is expected to be
    provided by the star import from actionbank.
    '''
    groups = []
    for g in range(1, _NUM_GROUPS + 1):
        gset = []
        for ci, cl in enumerate(classes):
            pattern = os.path.join(root, cl, '*g%02d*%s' % (g, banked_suffix))
            for f in glob.glob(pattern):
                gset.append([f, ci])
        print("group %d has %d" % (g, len(gset)))
        groups.append(gset)
    return groups


def main():
    '''Parse the command line and run the 5-fold group-wise cross-validation.'''
    parser = argparse.ArgumentParser(
        description="Script to perform 5-fold group-wise cross-validation "
                    "on the UCF 50 data set using the included SVM code.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "root",
        help="path to the directory containing the action bank processed "
             "ucf 50 files structured as in "
             "root/class/name_g00_c00_banked.npy.gz for each class")
    args = parser.parse_args()

    cdir = os.listdir(args.root)
    if len(cdir) != _NUM_CLASSES:
        # Reported only; the run continues with whatever was found.
        print("error: found %d classes, but there should be 50" % (len(cdir)))

    groups = _collect_groups(args.root, cdir)

    # Fold k holds the consecutive group indices [5k, 5k+5).
    full = np.arange(_NUM_GROUPS)
    fold_size = _NUM_GROUPS // _NUM_FOLDS
    sets = [np.arange(start, start + fold_size)
            for start in range(0, _NUM_GROUPS, fold_size)]

    accs = np.zeros((_NUM_FOLDS))
    for i in range(_NUM_FOLDS):
        # Train on every group outside fold i, test on fold i.
        accs[i] = testgroups(groups, np.setdiff1d(full, sets[i]), sets[i])

    print("mean accuracy is %f" % (accs.mean()))


if __name__ == '__main__':
    main()