11import { TFRecordsImageMessage , Features , Feature ,
22 BytesList , Int64List , FloatList } from "./tensorFlowRecordsProtoBuf_pb" ;
33import { crc32c , getInt32Buffer , getInt64Buffer , maskCrc , textEncode } from "./tensorFlowHelpers" ;
4- import { Transform , Readable , finished } from "stream" ;
4+ import { Transform , Readable , TransformOptions , finished } from "stream" ;
55
66// Conditionally import fs for Node.js environments
77let fs : typeof import ( "fs" ) | null = null ;
@@ -11,16 +11,36 @@ try {
1111 // Not available in browser
1212}
1313
14- export interface ITFRecordsFileWriter {
15- write ( record : Buffer ) : boolean ;
16- end ( ) : Promise < void > ;
17- }
18-
1914export interface TransformStreamOptions {
2015 highWaterMark ?: number ;
2116 filePath ?: string ;
2217}
2318
19+ /**
20+ * A Transform stream for TFRecords with an optional `finished` promise
21+ * that resolves when the stream (and any piped file) is complete.
22+ */
23+ export class TFRecordsTransform extends Transform {
24+ /**
25+ * A promise that resolves when the stream is finished writing.
26+ * When piped to a file, this waits for the file to be fully written.
27+ */
28+ public finished : Promise < void > ;
29+
30+ constructor ( options ?: TransformOptions , fileStream ?: NodeJS . WritableStream ) {
31+ super ( options ) ;
32+
33+ // If there's a file stream, wait for it to finish; otherwise wait for this transform
34+ const streamToWatch = fileStream || this ;
35+ this . finished = new Promise ( ( resolve , reject ) => {
36+ finished ( streamToWatch , ( err ) => {
37+ if ( err ) reject ( err ) ;
38+ else resolve ( ) ;
39+ } ) ;
40+ } ) ;
41+ }
42+ }
43+
2444/**
2545 * @name - TFRecords Feature Type
2646 * @description - Defines the type of TFRecords Feature
@@ -74,60 +94,48 @@ export class TFRecordsBuilder {
7494
7595 /**
7696 * @description - Create a Transform stream for TFRecords.
77- * Optionally writes directly to disk when filePath is provided.
97+ * Optionally pipes output directly to disk when filePath is provided.
7898 * @param optionsOrHighWaterMark - Stream buffer size (number) or options object
7999 * @param options.highWaterMark - Stream buffer size
80- * @param options.filePath - When provided, pipes output to this file (Node.js only)
81- * @returns - Transform stream, or ITFRecordsFileWriter when filePath is provided
100+ * @param options.filePath - When provided, pipes output to this file (Node.js only).
101+ * Use stream.finished promise to know when done.
102+ * @returns - TFRecordsTransform stream with a `finished` promise
82103 */
83- public static transformStream ( ) : Transform ;
84- public static transformStream ( highWaterMark : number ) : Transform ;
85- public static transformStream ( options : { highWaterMark ?: number } ) : Transform ;
86- public static transformStream ( options : { filePath : string ; highWaterMark ?: number } ) : ITFRecordsFileWriter ;
87- public static transformStream ( optionsOrHighWaterMark ?: TransformStreamOptions | number ) : Transform | ITFRecordsFileWriter {
104+ public static transformStream ( optionsOrHighWaterMark ?: TransformStreamOptions | number ) : TFRecordsTransform {
88105 const options : TransformStreamOptions | undefined =
89106 typeof optionsOrHighWaterMark === "number"
90107 ? { highWaterMark : optionsOrHighWaterMark }
91108 : optionsOrHighWaterMark ;
92109
93- const transformer = new Transform ( {
94- transform : ( record : Buffer , encoding , callback ) => {
95- const length = record . length ;
110+ let fileStream : NodeJS . WritableStream | undefined ;
111+ if ( options ?. filePath ) {
112+ if ( ! fs ) {
113+ throw new Error ( "File output is only available in Node.js. Use transformStream() without filePath in the browser." ) ;
114+ }
115+ fileStream = fs . createWriteStream ( options . filePath ) ;
116+ }
96117
97- // Get TFRecords CRCs for TFRecords Header and Footer
98- const bufferLength = getInt64Buffer ( length ) ;
99- const bufferLengthMaskedCRC = getInt32Buffer ( maskCrc ( crc32c ( bufferLength ) ) ) ;
100- const bufferDataMaskedCRC = getInt32Buffer ( maskCrc ( crc32c ( record ) ) ) ;
101- callback ( undefined , Buffer . concat ( [ bufferLength , bufferLengthMaskedCRC , record , bufferDataMaskedCRC ] ) ) ;
118+ const transformer = new TFRecordsTransform (
119+ {
120+ transform : ( record : Buffer , encoding , callback ) => {
121+ const length = record . length ;
122+
123+ // Get TFRecords CRCs for TFRecords Header and Footer
124+ const bufferLength = getInt64Buffer ( length ) ;
125+ const bufferLengthMaskedCRC = getInt32Buffer ( maskCrc ( crc32c ( bufferLength ) ) ) ;
126+ const bufferDataMaskedCRC = getInt32Buffer ( maskCrc ( crc32c ( record ) ) ) ;
127+ callback ( undefined , Buffer . concat ( [ bufferLength , bufferLengthMaskedCRC , record , bufferDataMaskedCRC ] ) ) ;
128+ } ,
129+ highWaterMark : options ?. highWaterMark ,
102130 } ,
103- highWaterMark : options ?. highWaterMark ,
104- } ) ;
131+ fileStream ,
132+ ) ;
105133
106- if ( ! options ?. filePath ) {
107- return transformer ;
134+ if ( fileStream ) {
135+ transformer . pipe ( fileStream ) ;
108136 }
109137
110- // File output mode
111- if ( ! fs ) {
112- throw new Error ( "File output is only available in Node.js. Use transformStream() without filePath in the browser." ) ;
113- }
114-
115- const fileStream = fs . createWriteStream ( options . filePath ) ;
116- transformer . pipe ( fileStream ) ;
117-
118- return {
119- write : ( record : Buffer ) => transformer . write ( record ) ,
120- end : ( ) => new Promise < void > ( ( resolve , reject ) => {
121- transformer . end ( ) ;
122- finished ( fileStream , ( err ) => {
123- if ( err ) {
124- reject ( err ) ;
125- } else {
126- resolve ( ) ;
127- }
128- } ) ;
129- } ) ,
130- } ;
138+ return transformer ;
131139 }
132140
133141 private features : Features ;
0 commit comments