2
2
__license__ = "Apache-2.0"
3
3
4
4
import os
5
- import itertools
6
- import csv
7
5
import shutil
8
- import click
9
6
import sys
7
+ import click
10
8
from backend_config import (
11
- text_length ,
12
9
max_docs ,
13
- backend_datafile ,
14
- backend_port ,
15
- backend_workdir ,
16
- backend_model ,
10
+ datafile ,
11
+ port ,
12
+ workdir ,
13
+ model
17
14
)
18
15
19
16
from executors .disk_indexer import DiskIndexer
20
- from executors .rankers import ReviewRanker
21
- from executors .encoders import MyTransformer
22
- import random
23
-
24
- from jina import Flow , Document
17
+ from helper import prep_docs
18
+ from jina import Flow
25
19
26
20
try :
27
21
__import__ ("pretty_errors" )
28
22
except ImportError :
29
23
pass
30
24
31
25
32
- def trim_string (
33
- input_string : str , word_count : int = text_length , sep : str = " "
34
- ) -> str :
35
- """
36
- Trim a string to a certain number of words.
37
- :param input_string: string to trim
38
- :param word_count: how many words to trim to
39
- :param sep: separator between words
40
- :return: trimmed string
41
- """
42
- sanitized_string = input_string .replace ("\\ n" , sep )
43
- words = sanitized_string .split (sep )[:word_count ]
44
- trimmed_string = " " .join (words )
45
-
46
- return trimmed_string
47
-
48
-
49
- def prep_docs (input_file : str , num_docs :int = max_docs ):
26
+ def index (num_docs : int = max_docs ):
50
27
"""
51
- Create a generator that yields one Document for every row in the CSV file
52
- :param input_file: Input csv filename
53
- :return: Generator
28
+ Build an index for your search
29
+ :param num_docs: maximum number of Documents to index
54
30
"""
55
-
56
- with open (input_file , "r" ) as csv_file :
57
- csv_reader = csv .DictReader (csv_file )
58
- input_field = "Description"
59
- for row in itertools .islice (csv_reader , num_docs ):
60
- # Fill in empty ratings and rating counts with random placeholder values
61
- if row ["Average User Rating" ] == "" :
62
- row ["Average User Rating" ] = random .uniform (0.0 , 5.0 )
63
- if row ["User Rating Count" ] == "" :
64
- row ["User Rating Count" ] = random .randint (10 , 10_000 )
65
- # Set field to encode and index
66
- input_data = trim_string (f"{ row ['Name' ]} - { trim_string (row [input_field ])} " )
67
- # Put all of that into a doc
68
- doc = Document (text = input_data )
69
- doc .tags = row
70
- yield doc
71
-
72
-
73
- def index (num_docs = max_docs ):
74
31
flow = (
75
32
Flow ()
76
- # .add(uses='jinahub+docker://TransformerTorchEncoder', pretrained_model_name_or_path="sentence-transformers/msmarco-distilbert-base-v3", name="encoder", max_length=50)
77
33
.add (
78
- uses = MyTransformer ,
79
- pretrained_model_name_or_path = backend_model ,
34
+ uses = "jinahub+docker://TransformerTorchEncoder" ,
35
+ pretrained_model_name_or_path = model ,
80
36
name = "encoder" ,
81
- ).add (uses = DiskIndexer , workspace = backend_workdir , name = "indexer" )
37
+ max_length = 50 ,
38
+ )
39
+ .add (uses = DiskIndexer , workspace = workdir )
82
40
)
83
41
84
42
with flow :
85
43
flow .post (
86
44
on = "/index" ,
87
- inputs = prep_docs (input_file = backend_datafile , num_docs = num_docs ),
45
+ inputs = prep_docs (input_file = datafile , num_docs = num_docs ),
88
46
request_size = 64 ,
89
47
read_mode = "r" ,
90
48
)
91
49
92
50
93
51
def query_restful ():
52
+ """
53
+ Serve the index over HTTP so it can be queried
54
+ """
94
55
flow = (
95
56
Flow ()
96
- # .add(uses='jinahub+docker://TransformerTorchEncoder', pretrained_model_name_or_path="sentence-transformers/msmarco-distilbert-base-v3", name="encoder", max_length=50)
97
57
.add (
98
- uses = MyTransformer ,
99
- pretrained_model_name_or_path = backend_model ,
58
+ uses = "jinahub+docker://TransformerTorchEncoder" ,
59
+ pretrained_model_name_or_path = "sentence-transformers/msmarco-distilbert-base-v3" ,
100
60
name = "encoder" ,
101
- ).add (uses = DiskIndexer , workspace = backend_workdir , name = "indexer" )
102
- # .add(uses=ReviewRanker, name="ranker")
61
+ max_length = 50 ,
62
+ )
63
+ .add (uses = DiskIndexer , workspace = workdir )
103
64
)
104
65
105
66
with flow :
106
67
flow .protocol = "http"
107
- flow .port_expose = backend_port
68
+ flow .port_expose = port
108
69
flow .block ()
109
70
110
71
@@ -117,7 +78,7 @@ def query_restful():
117
78
@click .option ("--num_docs" , "-n" , default = max_docs )
118
79
@click .option ("--force" , "-f" , is_flag = True )
119
80
def main (task : str , num_docs : int , force : bool ):
120
- workspace = backend_workdir
81
+ workspace = workdir
121
82
if task == "index" :
122
83
if os .path .exists (workspace ):
123
84
if force :
0 commit comments