在使用TensorFlow訓練神經網絡時,首先面臨的問題是:網絡的輸入
此篇文章,教大家將自己的數據集制作成TFRecord格式,feed進網絡,除了TFRecord格式,TensorFlow也支持其他格
式的數據,此處就不再介紹了。建議大家使用TFRecord格式,在后面可以通過api進行多線程的讀取文件隊列。
1. 原本的數據集
此時,我有兩類圖片,分別是xiansu100,xiansu60,每一類中有10張圖片。
2.制作成TFRecord格式
tfrecord會根據你選擇輸入文件的類,自動給每一類打上同樣的標簽。如在本例中,只有0,1 兩類,想知道文件夾名與label關系的,可以自己保存起來。
#生成整數型的屬性 def _int64_feature(value): return tf.train.Feature(int64_list = tf.train.Int64List(value = [value])) #生成字符串類型的屬性 def _bytes_feature(value): return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value])) #制作TFRecord格式 def createTFRecord(filename,mapfile): class_map = {} data_dir = '/home/wc/DataSet/traffic/testTFRecord/' classes = {'xiansu60','xiansu100'} #輸出TFRecord文件的地址 writer = tf.python_io.TFRecordWriter(filename) for index,name in enumerate(classes): class_path=data_dir+name+'/' class_map[index] = name for img_name in os.listdir(class_path): img_path = class_path + img_name #每個圖片的地址 img = Image.open(img_path) img= img.resize((224,224)) img_raw = img.tobytes() #將圖片轉化成二進制格式 example = tf.train.Example(features = tf.train.Features(feature = { 'label':_int64_feature(index), 'image_raw': _bytes_feature(img_raw) })) writer.write(example.SerializeToString()) writer.close() txtfile = open(mapfile,'w+') for key in class_map.keys(): txtfile.writelines(str(key)+":"+class_map[key]+"\n") txtfile.close()