10
10
from collections .abc import Iterator
11
11
from contextvars import ContextVar
12
12
from datetime import datetime
13
+ from threading import Lock
13
14
from typing import Callable , Dict , List , Optional
14
15
15
- from opentelemetry import trace
16
+ import opentelemetry . trace as otel_trace
16
17
from opentelemetry .trace .status import StatusCode
17
18
18
19
from promptflow ._core .generator_proxy import GeneratorProxy , generate_from_proxy
24
25
25
26
from .thread_local_singleton import ThreadLocalSingleton
26
27
27
- open_telemetry_tracer = trace .get_tracer ("promptflow" )
28
+
29
+ open_telemetry_tracer = otel_trace .get_tracer ("promptflow" )
28
30
29
31
30
32
class Tracer (ThreadLocalSingleton ):
@@ -153,6 +155,46 @@ def _format_error(error: Exception) -> dict:
153
155
}
154
156
155
157
158
class TokenCollector:
    """Collect OpenAI token usage per span and roll it up to parent spans.

    Token counts are stored per span id under attribute names of the form
    ``__computed__.cumulative_token_count.<name>``, where ``<name>`` is the
    first ``_``-separated segment of the usage field (``prompt_tokens`` ->
    ``prompt``). When a span finishes, its counts can be added to the entry
    of its parent span so that the parent carries the cumulative total.
    """

    # Shared by all instances; guards every access to ``_span_id_to_tokens``.
    _lock = Lock()

    _ATTRIBUTE_PREFIX = "__computed__.cumulative_token_count."

    def __init__(self):
        # span id -> {attribute name: token count}
        # NOTE(review): nothing in this class removes entries, so the mapping
        # grows with the number of spans that reported usage — confirm whether
        # a cleanup hook exists elsewhere.
        self._span_id_to_tokens = {}

    def collect_openai_tokens(self, span, output):
        """Record the token usage reported on ``output`` for ``span``.

        Nothing is recorded when ``output`` is a generator, has no ``usage``
        attribute, has ``usage`` set to ``None``, or reports no numeric counts.

        :param span: Span exposing ``get_span_context().span_id``.
        :param output: Return value of the traced call; its ``usage`` may be a
            plain dict or an object exposing ``dict()``.
        """
        span_id = span.get_span_context().span_id
        if inspect.isgenerator(output):
            return
        usage = getattr(output, "usage", None)
        if usage is None:
            return

        # Accept dict-like usage objects as well as models exposing ``dict()``.
        usage_fields = usage if isinstance(usage, dict) else usage.dict()

        # Keep numeric counters only. Non-numeric fields (e.g. ``None`` or the
        # nested ``*_tokens_details`` objects) would otherwise collapse onto the
        # same ``prompt``/``completion`` key, clobber the real count and break
        # the arithmetic done when merging into the parent span.
        tokens = {
            f"{self._ATTRIBUTE_PREFIX}{field.split('_')[0]}": count
            for field, count in usage_fields.items()
            if isinstance(count, (int, float)) and not isinstance(count, bool)
        }
        if tokens:
            with self._lock:
                self._span_id_to_tokens[span_id] = tokens

    def collect_openai_tokens_for_parent_span(self, span):
        """Add the tokens recorded for ``span`` to the entry of its parent span.

        Does nothing when no tokens were recorded for ``span`` or when the span
        has no ``parent`` (attribute missing or ``None``).

        :param span: Span exposing ``get_span_context().span_id`` and,
            optionally, ``parent.span_id``.
        """
        tokens = self.try_get_openai_tokens(span.get_span_context().span_id)
        if not tokens:
            return
        parent = getattr(span, "parent", None)
        if parent is None:
            return

        parent_span_id = parent.span_id
        with self._lock:
            parent_tokens = self._span_id_to_tokens.get(parent_span_id)
            if parent_tokens is None:
                # Store a copy so parent and child never share one dict object.
                self._span_id_to_tokens[parent_span_id] = dict(tokens)
            else:
                self._span_id_to_tokens[parent_span_id] = {
                    key: parent_tokens.get(key, 0) + tokens.get(key, 0)
                    for key in set(parent_tokens) | set(tokens)
                }

    def try_get_openai_tokens(self, span_id):
        """Return the token attributes recorded for ``span_id``, or ``None``."""
        with self._lock:
            return self._span_id_to_tokens.get(span_id, None)
193
+
194
+
195
# Module-level collector instance used by the tracing helpers in this file.
token_collector = TokenCollector()
196
+
197
+
156
198
def _create_trace_from_function_call (
157
199
f , * , args = None , kwargs = None , args_to_ignore : Optional [List [str ]] = None , trace_type = TraceType .FUNCTION
158
200
):
@@ -205,6 +247,14 @@ def get_node_name_from_context():
205
247
return None
206
248
207
249
250
def enrich_span_with_context(span):
    """Copy the operation context's OpenTelemetry attributes onto ``span``.

    Best effort: any exception raised while reading or setting the attributes
    is reported through ``logging.warning`` and is not propagated.

    :param span: Span object exposing ``set_attributes``.
    """
    try:
        operation_context = OperationContext.get_instance()
        otel_attributes = operation_context._get_otel_attributes()
        span.set_attributes(otel_attributes)
    except Exception as error:
        logging.warning(f"Failed to enrich span with context: {error}")
256
+
257
+
208
258
def enrich_span_with_trace (span , trace ):
209
259
try :
210
260
span .set_attributes (
@@ -215,8 +265,7 @@ def enrich_span_with_trace(span, trace):
215
265
"node_name" : get_node_name_from_context (),
216
266
}
217
267
)
218
- attrs_from_context = OperationContext .get_instance ()._get_otel_attributes ()
219
- span .set_attributes (attrs_from_context )
268
+ enrich_span_with_context (span )
220
269
except Exception as e :
221
270
logging .warning (f"Failed to enrich span with trace: { e } " )
222
271
@@ -235,6 +284,9 @@ def enrich_span_with_output(span, output):
235
284
try :
236
285
serialized_output = serialize_attribute (output )
237
286
span .set_attribute ("output" , serialized_output )
287
+ tokens = token_collector .try_get_openai_tokens (span .get_span_context ().span_id )
288
+ if tokens :
289
+ span .set_attributes (tokens )
238
290
except Exception as e :
239
291
logging .warning (f"Failed to enrich span with output: { e } " )
240
292
@@ -306,12 +358,16 @@ async def wrapped(*args, **kwargs):
306
358
Tracer .push (trace )
307
359
enrich_span_with_input (span , trace .inputs )
308
360
output = await func (* args , ** kwargs )
361
+ if trace_type == TraceType .LLM :
362
+ token_collector .collect_openai_tokens (span , output )
309
363
enrich_span_with_output (span , output )
310
364
span .set_status (StatusCode .OK )
311
- return Tracer .pop (output )
365
+ output = Tracer .pop (output )
312
366
except Exception as e :
313
367
Tracer .pop (None , e )
314
368
raise
369
+ token_collector .collect_openai_tokens_for_parent_span (span )
370
+ return output
315
371
316
372
wrapped .__original_function = func
317
373
@@ -351,12 +407,16 @@ def wrapped(*args, **kwargs):
351
407
Tracer .push (trace )
352
408
enrich_span_with_input (span , trace .inputs )
353
409
output = func (* args , ** kwargs )
410
+ if trace_type == TraceType .LLM :
411
+ token_collector .collect_openai_tokens (span , output )
354
412
enrich_span_with_output (span , output )
355
413
span .set_status (StatusCode .OK )
356
- return Tracer .pop (output )
414
+ output = Tracer .pop (output )
357
415
except Exception as e :
358
416
Tracer .pop (None , e )
359
417
raise
418
+ token_collector .collect_openai_tokens_for_parent_span (span )
419
+ return output
360
420
361
421
wrapped .__original_function = func
362
422
0 commit comments