diff --git a/doc/flags.md b/doc/flags.md
index f2652bbf..1d106882 100644
--- a/doc/flags.md
+++ b/doc/flags.md
@@ -57,3 +57,4 @@ The sampling script `sample.lua` accepts the following command-line flags:
 - `-gpu`: The ID of the GPU to use (zero-indexed). Default is 0. Set this to -1 to run in CPU-only mode.
 - `-gpu_backend`: The GPU backend to use; either `cuda` or `opencl`. Default is `cuda`.
 - `-verbose`: By default just the sampled text is printed to the console. Set this to 1 to also print some diagnostic information.
+- `-output`: Path of a file to which the sampled text should be written. The default is `-`, which writes the sampled text to standard output together with any diagnostic messages.
diff --git a/sample.lua b/sample.lua
index 4e6ebae0..00024fc5 100644
--- a/sample.lua
+++ b/sample.lua
@@ -13,6 +13,7 @@ cmd:option('-temperature', 1)
 cmd:option('-gpu', 0)
 cmd:option('-gpu_backend', 'cuda')
 cmd:option('-verbose', 0)
+cmd:option('-output', '-')
 
 local opt = cmd:parse(arg)
 
@@ -39,4 +40,10 @@ if opt.verbose == 1 then print(msg) end
 
 model:evaluate()
 local sample = model:sample(opt)
-print(sample)
+if opt.output == '-' then
+  print(sample)
+else
+  -- io is a core Lua library (no require needed); assert() aborts with a
+  -- readable message if the output file cannot be opened for writing.
+  local outfile = assert(io.open(opt.output, 'w'))
+  outfile:write(sample)
+  outfile:close()
+end