Skip to content

Commit 2f9777e

Browse files
author
SivilTaram
committed
update pipeline & readme
1 parent c25f350 commit 2f9777e

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

awakening_latent_grounding/README.md

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# README
2+
3+
This code is still incomplete. We will actively update it once it is fully ready.
4+
15
# Environment Setup
26

37
```

awakening_latent_grounding/inference/pipeline_torchscript.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,24 @@
55
import torch
66
from .pipeline_base import NLBindingInferencePipeline
77

8+
89
class NLBindingTorchScriptPipeline(NLBindingInferencePipeline):
910
def __init__(self,
10-
model_dir: str,
11-
greedy_linking: bool,
12-
threshold: float=0.2,
13-
num_threads: int=8,
14-
use_gpu: bool = torch.cuda.is_available()
15-
) -> None:
11+
model_dir: str,
12+
greedy_linking: bool,
13+
threshold: float = 0.2,
14+
num_threads: int = 8,
15+
use_gpu: bool = torch.cuda.is_available()
16+
) -> None:
1617
super().__init__(model_dir, greedy_linking=greedy_linking, threshold=threshold)
1718

1819
self.device = torch.device('cuda') if use_gpu else torch.device('cpu')
1920
model_file = "nl_binding.script.bin"
2021

2122
torch.set_num_interop_threads(2)
2223
torch.set_num_threads(num_threads)
23-
print('Torch model Threads: {}, {}, {}'.format(torch.get_num_interop_threads(), torch.get_num_threads(), self.device))
24+
print('Torch model Threads: {}, {}, {}'.format(torch.get_num_interop_threads(), torch.get_num_threads(),
25+
self.device))
2426
model_ckpt_path = os.path.join(model_dir, model_file)
2527
self.model = torch.jit.load(model_ckpt_path, map_location=self.device)
2628
self.model.eval()

0 commit comments

Comments
 (0)