1. 程式人生 > java生成tfrecord檔案

java生成tfrecord檔案

最近在做一個 NER 識別模型，用演算法標註了一億出頭的資料。剛開始用 Python 生成 tfrecord 檔案，這一步非常耗時：一億條資料、每一千萬條生成一個檔案，差不多要兩個半小時才能全部生成完。左思右想，想找一種更快的方法。前幾天看到美團公眾號發了一篇關於 CTR 模型的文章，其中有一句話印象深刻：用 Spark 生成 tfrecord 檔案。雖然只是一筆帶過，但受到了啟發：既然 Spark 是 JVM 系的，用 Java 一定也可以。於是查了點資料，並和第一家公司的同事討論了一下，發現確實可以用 protobuf 生成 Java 程式碼，再通過這些 Java 程式碼生成 tfrecord 檔案，然後在 Python 中讀取生成的 tfrecord 檔案餵給模型。這一步由原來的兩個半小時縮短到大約 12 分鐘，差距可想而知。我估計用 C++ 會更快，改天再用 C++ 寫一篇。下面看 Java 生成 tfrecord 的具體步驟：

前提是電腦上必須安裝 protobuf（protoc），我裝的版本如下（protoc 的版本需與後面 pom.xml 中 protobuf-java 的版本一致，這裡都是 3.5.1）。關於怎麼安裝，我在前面的部落格中已經提到：

zhoumeixudeMacBook-Pro:ner zhoumeixu$ protoc --version
libprotoc 3.5.1

首先從 TensorFlow 官方原始碼倉庫下載兩個 proto 檔案，位於 tensorflow/core/example/ 目錄下：example.proto 與 feature.proto。

然後用 protobuf 生成 Java 程式碼（用法和 thrift 差不多）：

protoc  --java_out=.  feature.proto 
protoc  --java_out=.  example.proto 

不出意外會生成一個 org/tensorflow/example 資料夾，把裡面生成的程式碼拷貝到 Java 工程下，package 結構如下：

看到這裡是不是很熟悉？這些類和 Python 中生成 tfrecord 檔案時用到的類完全一樣，原因是兩者持久化時都使用 protobuf 協議，Python 那一層也只是一層殼。

在 Java 中使用還需要引入 protobuf-java 包（以及下面程式碼用到的 guava），pom.xml 配置如下：

    <dependency>
			<groupId>com.google.protobuf</groupId>
			<artifactId>protobuf-java</artifactId>
			<version>3.5.1</version>
		</dependency>

		<!-- https://mvnrepository.com/artifact/com.google.guava/guava -->
		<dependency>
			<groupId>com.google.guava</groupId>
			<artifactId>guava</artifactId>
			<version>23.6-jre</version>
		</dependency>

生成 tfrecord 檔案的 Java 程式碼主要涉及兩個類：一個是重寫的 TFRecordFileWriter，另一個是 TFRecordNer。下面看程式碼：

TFRecordFileWriter:

package com.xxx.protobuf;

import java.io.BufferedOutputStream;
import java.io.Closeable;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;


/**
 * Writes {@code org.tensorflow.example.Example} protos (or already-serialized bytes) to a file in
 * the TFRecord format.
 *
 * <p>Every record is laid out as:
 *
 * <pre>
 *   uint64 length                      little-endian
 *   uint32 masked CRC32-C of length    little-endian
 *   byte   data[length]
 *   uint32 masked CRC32-C of data      little-endian
 * </pre>
 *
 * <p>No synchronization is performed and a shared scratch buffer is used, so an instance must only
 * be used from one thread at a time. {@link #close()} flushes and closes the underlying stream.
 */
public class TFRecordFileWriter implements Closeable {

  /**
   * CRC32-C (Castagnoli, reflected polynomial {@code 0x82F63B78}) as specified in RFC 3309.
   *
   * <p>The 256-entry lookup table is generated once at class initialization. Java has no unsigned
   * int, so the running value is kept in an {@code int} and widened with a mask when read.
   */
  private static final class Crc32c {
    private static final int POLY = 0x82F63B78;
    private static final int[] CRC_TABLE = buildTable();
    private static final long LONG_MASK = 0xFFFFFFFFL;
    private static final int BYTE_MASK = 0xFF;

    private int crc = ~0;

    /** Builds the byte-wise lookup table for the reflected Castagnoli polynomial. */
    private static int[] buildTable() {
      int[] table = new int[256];
      for (int n = 0; n < table.length; n++) {
        int c = n;
        for (int k = 0; k < 8; k++) {
          c = (c & 1) != 0 ? (c >>> 1) ^ POLY : c >>> 1;
        }
        table[n] = c;
      }
      return table;
    }

    public void update(int b) {
      crc = crc32c(crc, b);
    }

    public void update(byte[] buffer, int offset, int length) {
      for (int i = offset, end = offset + length; i < end; i++) {
        crc = crc32c(crc, buffer[i]);
      }
    }

    /** Returns the finalized CRC as an unsigned 32-bit value held in a long. */
    public long getValue() {
      return (~crc) & LONG_MASK;
    }

    public void reset() {
      crc = ~0;
    }

    private static int crc32c(int crc, int b) {
      // Shift out the low byte and fold in the table entry for (low byte XOR input byte).
      return (crc >>> 8) ^ CRC_TABLE[(crc ^ b) & BYTE_MASK];
    }
  }

  /** Writes {@code v} as 8 little-endian bytes, using {@code writeBuffer} (length >= 8) as scratch. */
  public static final void writeLong(BufferedOutputStream out, long v, byte[] writeBuffer)
      throws IOException {
    for (int i = 0; i < 8; i++) {
      writeBuffer[i] = (byte) (v >>> (8 * i));
    }
    out.write(writeBuffer, 0, 8);
  }

  /** Writes {@code v} as 4 little-endian bytes, using {@code writeBuffer} (length >= 4) as scratch. */
  public static final void writeInt(BufferedOutputStream out, int v, byte[] writeBuffer)
      throws IOException {
    for (int i = 0; i < 4; i++) {
      writeBuffer[i] = (byte) (v >>> (8 * i));
    }
    out.write(writeBuffer, 0, 4);
  }

  /** Added to the rotated CRC by the TFRecord masking scheme (unsigned 32-bit constant). */
  private static final long MASK_DELTA = 0xa282ead8L;
  private static final long UINT32_MASK = 0xFFFFFFFFL;

  private final BufferedOutputStream out;
  /** Scratch space for the length and CRC fields of the record currently being written. */
  private final byte[] scratch = new byte[8];

  TFRecordFileWriter(Path p) throws IOException {
    out = new BufferedOutputStream(Files.newOutputStream(p));
  }

  TFRecordFileWriter(String path) throws IOException {
    out = new BufferedOutputStream(new FileOutputStream(path));
  }

  /** Serializes {@code example} and appends it as one TFRecord. */
  void writeExample(org.tensorflow.example.Example example) throws IOException {
    writeRecord(example.toByteArray());
  }

  /**
   * Appends {@code data} as one TFRecord (length, masked length CRC, payload, masked payload CRC).
   * The length field is taken from {@code data.length}, so it always matches the payload written.
   */
  void writeRecord(byte[] data) throws IOException {
    writeLong(out, data.length, scratch);
    // Compute the header CRC before writeInt reuses the scratch buffer.
    int lengthCrc = maskCrc(scratch, 8);
    writeInt(out, lengthCrc, scratch);
    out.write(data);
    writeInt(out, maskCrc(data, data.length), scratch);
  }

  /**
   * Returns the TFRecord-masked CRC32-C of {@code buffer[0, length)}: the CRC rotated right by 15
   * bits plus {@link #MASK_DELTA}, all modulo 2^32.
   */
  private static int maskCrc(byte[] buffer, int length) {
    Crc32c crc32c = new Crc32c();
    crc32c.update(buffer, 0, length);
    long crc = crc32c.getValue();
    long rotated = ((crc >>> 15) | (crc << 17)) & UINT32_MASK;
    return (int) ((rotated + MASK_DELTA) & UINT32_MASK);
  }

  /** Flushes buffered records and closes the file. */
  @Override
  public void close() throws IOException {
    out.close();
  }
}

TFRecordNer:

package com.xxx.protobuf;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.tensorflow.example.Example;
import org.tensorflow.example.Feature;
import org.tensorflow.example.Features;
import org.tensorflow.example.Int64List;

import com.google.common.primitives.Longs;

/**
 * Converts a tab-separated NER corpus into sharded TFRecord files.
 *
 * <p>Each usable input line has the form {@code tok1|tok2|...<TAB>lab1|lab2|...} with equal numbers
 * of tokens and labels. Tokens and labels are mapped to integer ids through dictionaries loaded by
 * {@link #getDict(String)}, padded or truncated to {@link #SEQ_LEN} positions, and written as an
 * {@code Example} with two int64 features named {@code "input"} and {@code "label"}.
 */
public class TFRecordNer {
	/** Positions per example; the Python reader in this post parses FixedLenFeature([80]). */
	private static final int SEQ_LEN = 80;
	/** Id for tokens missing from the input dictionary. */
	private static final int UNKNOWN_INPUT_ID = 1;
	/** Id for padding positions and for labels missing from the output dictionary. */
	private static final int DEFAULT_LABEL_ID = 8;
	/** Maximum number of examples written to one shard file. */
	private static final int MAX_RECORDS_PER_FILE = 10000000;
	/** A progress line is printed every this many input lines. */
	private static final long LOG_EVERY = 1000000L;

	/**
	 * Entry point.
	 *
	 * @param args optional {@code [inputPath [outputPattern]]}; falls back to the original local paths
	 */
	public static void main(String[] args) {
		String inputpath = args.length > 0 ? args[0] : "/Users/zhoumeixu/Desktop/all.txt";
		String outputpath = args.length > 1 ? args[1] : "/Users/zhoumeixu/Desktop/ner/tfner.records-%s";
		transCorpus(inputpath, outputpath);
	}

	/**
	 * Reads {@code inputpath} (UTF-8) and writes shards {@code String.format(outputpath, 0)},
	 * {@code String.format(outputpath, 1)}, ..., each holding exactly {@link #MAX_RECORDS_PER_FILE}
	 * examples except the last. Lines that do not split into exactly two tab-separated fields, or
	 * whose token and label counts differ, are skipped. Errors are printed rather than thrown, and
	 * every opened file is closed either way.
	 */
	public static void transCorpus(String inputpath, String outputpath) {
		Map<String, Integer> mapinput = getDict("file/nerinput.txt");
		Map<String, Integer> mapoutput = getDict("file/neroutput.txt");

		TFRecordFileWriter tw = null;
		try (BufferedReader br = openUtf8(inputpath)) {
			int shard = 0;
			int inShard = 0;
			long count = 0L;
			tw = new TFRecordFileWriter(String.format(outputpath, shard));

			for (String line = br.readLine(); line != null; line = br.readLine()) {
				count++;
				if (count % LOG_EVERY == 0) {
					System.out.println("count is:" + count + "  cn is:" + inShard);
				}

				String[] lines = line.split("\t");
				if (lines.length != 2) {
					continue;
				}
				String[] inputs = lines[0].split("\\|");
				String[] outputs = lines[1].split("\\|");
				if (inputs.length != outputs.length) {
					continue;
				}

				if (inShard == MAX_RECORDS_PER_FILE) {
					// The full shard must be closed (flushed) before moving on to the next file.
					tw.close();
					tw = null;
					shard++;
					tw = new TFRecordFileWriter(String.format(outputpath, shard));
					inShard = 0;
				}
				tw.writeExample(buildExample(inputs, outputs, mapinput, mapoutput));
				inShard++;
			}

			tw.close();
			tw = null;
		} catch (Exception e) {
			e.printStackTrace();
		} finally {
			// Only non-null here when an exception interrupted the loop.
			closeQuietly(tw);
		}
	}

	/** Encodes one aligned token/label sequence as a fixed-length {@code Example}. */
	private static Example buildExample(String[] inputs, String[] outputs,
			Map<String, Integer> mapinput, Map<String, Integer> mapoutput) {
		long[] inputIds = new long[SEQ_LEN]; // padding positions stay 0
		long[] labelIds = new long[SEQ_LEN];
		Arrays.fill(labelIds, DEFAULT_LABEL_ID);

		int len = Math.min(inputs.length, SEQ_LEN);
		for (int i = 0; i < len; i++) {
			// A literal space token is stored in the dictionary under the key "blank".
			String key = " ".equals(inputs[i]) ? "blank" : inputs[i];
			inputIds[i] = mapinput.getOrDefault(key, UNKNOWN_INPUT_ID);
			labelIds[i] = mapoutput.getOrDefault(outputs[i], DEFAULT_LABEL_ID);
		}

		Features features = Features.newBuilder()
				.putFeature("input", int64Feature(inputIds))
				.putFeature("label", int64Feature(labelIds))
				.build();
		return Example.newBuilder().setFeatures(features).build();
	}

	/** Wraps {@code values} in an int64-list {@code Feature}. */
	private static Feature int64Feature(long[] values) {
		List<Long> boxed = Longs.asList(values);
		return Feature.newBuilder()
				.setInt64List(Int64List.newBuilder().addAllValue(boxed).build())
				.build();
	}

	/** Closes {@code tw} if non-null, printing (not propagating) any failure. */
	private static void closeQuietly(TFRecordFileWriter tw) {
		if (tw == null) {
			return;
		}
		try {
			tw.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
	}

	/** Opens {@code path} for buffered reading as UTF-8 text. */
	private static BufferedReader openUtf8(String path) throws IOException {
		return new BufferedReader(new InputStreamReader(new FileInputStream(path), StandardCharsets.UTF_8));
	}

	/**
	 * Loads a dictionary file where the line at 0-based index {@code i} maps to id {@code i}
	 * (a later duplicate line overwrites an earlier id). On a read error the stack trace is printed
	 * and the entries read so far are returned.
	 *
	 * @param path UTF-8 text file with one entry per line
	 * @return mutable map from entry to id
	 */
	public static Map<String, Integer> getDict(String path) {
		Map<String, Integer> map = new HashMap<String, Integer>();
		try (BufferedReader br = openUtf8(path)) {
			int id = 0;
			for (String line = br.readLine(); line != null; line = br.readLine()) {
				map.put(line, id++);
			}
		} catch (Exception e) {
			e.printStackTrace();
		}
		return map;
	}
}

用 Python 讀取生成的 tfrecord 檔案：


import  tensorflow  as tf

def read_and_decode(filename_queue, seq_len=80):
    """Read one serialized Example from ``filename_queue`` and decode it.

    Args:
        filename_queue: queue of TFRecord filenames, e.g. the result of
            ``tf.train.string_input_producer``.
        seq_len: fixed length of the ``input`` and ``label`` int64 features;
            must equal the length used by the writer (80 in the Java code).

    Returns:
        ``(x, y)``: the ``input`` and ``label`` features cast to int32 tensors
        of shape ``[seq_len]``.
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        "input": tf.FixedLenFeature([seq_len], tf.int64),
        "label": tf.FixedLenFeature([seq_len], tf.int64),
    })

    x = tf.cast(features["input"], tf.int32)
    y = tf.cast(features["label"], tf.int32)
    return x, y





def inputs(file, batch_size, num_epochs, min_after_dequeue=8887):
    """Build a shuffled batching pipeline over TFRecord files.

    Args:
        file: list (or string tensor) of TFRecord filenames.
        batch_size: number of examples per batch.
        num_epochs: number of passes over ``file``; a falsy value cycles
            forever. A finite value creates a local epoch counter, so the
            caller must run ``tf.local_variables_initializer()``.
        min_after_dequeue: minimum number of examples left in the shuffle
            queue after each dequeue (controls shuffle quality).

    Returns:
        ``(x, y)`` batched tensors produced by ``read_and_decode``.
    """
    if not num_epochs:
        num_epochs = None
    filename_queue = tf.train.string_input_producer(file, num_epochs=num_epochs)
    feature, label = read_and_decode(filename_queue)
    # capacity must exceed min_after_dequeue by at least batch_size, otherwise a
    # dequeue of batch_size elements can never be satisfied and the pipeline stalls.
    capacity = min_after_dequeue + 3 * batch_size
    x, y = tf.train.shuffle_batch([feature, label], batch_size=batch_size,
                                  capacity=capacity,
                                  min_after_dequeue=min_after_dequeue)
    return x, y

if __name__ == "__main__":
    # Read back the shards written by the Java TFRecordNer and print a few examples.
    files = tf.train.match_filenames_once("/Users/zhoumeixu/Desktop/ner/tfner.records-*")
    x, y = inputs(files, 1, 11)
    with tf.Session() as sess:
        # match_filenames_once stores its result in a local variable.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        print(sess.run(files))
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        try:
            for _ in range(10):
                x_result, y_result = sess.run([x, y])
                print(x_result, y_result)
        finally:
            # Always stop and join the queue-runner threads, even if a read fails,
            # so the session can close and the process can exit.
            coord.request_stop()
            coord.join(threads)

結果（輸出已截斷）：

b'/Users/zhoumeixu/Desktop/ner/tfner.records-6'
 b'/Users/zhoumeixu/Desktop/ner/tfner.records-7'
 b'/Users/zhoumeixu/Desktop/ner/tfner.records-8'
 b'/Users/zhoumeixu/Desktop/ner/tfner.records-9']
[[ 9157 11163 16887   611 18227   329 11772   258 19841 11163   589 16169
  19681  9956 13159  2257 18369 11586  4867  9922    48     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0