generate_embeddings.py
import pandas as pd
import os
import yaml
from tqdm import tqdm # For progress bar
import torch
from concurrent.futures import ThreadPoolExecutor
from sentence_transformers import SentenceTransformer
import numpy as np
# Load settings from the YAML configuration file
with open('config.yaml') as f:
    config = yaml.safe_load(f)
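# The keys read below ('embedding_model', 'batch_size', 'output_embeddings') are the only
# ones this script uses. A minimal config.yaml might look like the following; the example
# values are assumptions, not the project's actual settings:
#
#   embedding_model: "sentence-transformers/all-MiniLM-L6-v2"
#   batch_size: 64
#   output_embeddings: "embeddings"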
# truncate_dim=256
model = SentenceTransformer(config['embedding_model'])
# Load the DataFrame from a Parquet file
df = pd.read_parquet('tce.parquet')
# Extract the 'Historico' column as a list of strings (row order is preserved,
# so batch i covers rows i*batch_size to (i+1)*batch_size - 1)
data = df['Historico'].astype(str).tolist()
def encode_batch(batch):
    # Encode a batch of texts, returning an array of embeddings
    # (batch is already a list of strings, so it can be passed to encode() directly)
    return model.encode(batch)
# Split data into batches
batch_size = config['batch_size']
batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
# Parallel processing with threads
with ThreadPoolExecutor(max_workers=4) as executor:  # Adjust max_workers as needed
    for i, result in enumerate(tqdm(executor.map(encode_batch, batches), total=len(batches))):
        # Save each batch of embeddings to its own .npy file in the output directory
        np.save(os.path.join(config['output_embeddings'], f"embeddings_batch_{i}.npy"), result)
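# A minimal sketch of how a downstream step could reload and stack the per-batch files
# written above; the glob pattern and the final 'embeddings.npy' filename are assumptions,
# not part of this script:
#
#   import glob
#   import numpy as np
#
#   paths = sorted(glob.glob(f"{config['output_embeddings']}/embeddings_batch_*.npy"),
#                  key=lambda p: int(p.rsplit('_', 1)[1].split('.')[0]))
#   all_embeddings = np.vstack([np.load(p) for p in paths])  # shape: (len(data), dim)
#   np.save(f"{config['output_embeddings']}/embeddings.npy", all_embeddings)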