@@ -741,61 +741,97 @@ def post_task_results():
741741 def _setup_tensor_route (self ):
742742 """Set up the /tensors/aggregated endpoint."""
743743
744- @self .app .route (f"/{ self .api_prefix } /tensors/aggregated" , methods = ["GET " ])
745- def get_aggregated_tensor ():
746- """Endpoint for collaborators to retrieve an aggregated tensor ."""
744+ @self .app .route (f"/{ self .api_prefix } /tensors/aggregated/batch " , methods = ["POST " ])
745+ def get_aggregated_tensors ():
746+ """Endpoint for collaborators to retrieve multiple aggregated tensors ."""
747747 start_time = time .time ()
748748
749749 # Validate that this endpoint is not used in connector mode
750750 if self .use_connector :
751- abort (501 , "GetAggregatedTensor not supported in connector mode" )
751+ abort (501 , "GetAggregatedTensors not supported in connector mode" )
752752
753- # Get and validate collaborator identity
754- collaborator_id = request .args .get ("collaborator_id" )
755- federation_id = request .args .get ("federation_uuid" )
753+ try :
754+ # Parse the incoming JSON to a GetAggregatedTensorsRequest protobuf message
755+ request_data = request .get_json ()
756+ if not request_data :
757+ abort (400 , "Invalid JSON payload" )
756758
757- # Use the consolidated validation method
758- self . _is_authorized ( collaborator_id , federation_id )
759+ tensors_request = aggregator_pb2 . GetAggregatedTensorsRequest ()
760+ json_format . ParseDict ( request_data , tensors_request , ignore_unknown_fields = True )
759761
760- # Extract tensor request parameters
761- tensor_name = request .args .get ("tensor_name" )
762- try :
763- round_number = int (request .args .get ("round_number" , 0 ))
764- except (TypeError , ValueError ):
765- abort (400 , "Invalid round number" )
766- report = request .args .get ("report" , "" ).lower () == "true"
767- tags = request .args .getlist ("tags" )
768- require_lossless = request .args .get ("require_lossless" , "" ).lower () == "true"
769-
770- # Get the tensor from aggregator - direct delegation to the aggregator
771- named_tensor = self .aggregator .get_aggregated_tensor (
772- tensor_name ,
773- round_number ,
774- report = report ,
775- tags = tuple (tags ),
776- require_lossless = require_lossless ,
777- requested_by = collaborator_id ,
778- )
762+ # Validate headers and get collaborator identity
763+ collaborator_id = tensors_request .header .sender
764+ federation_id = tensors_request .header .federation_uuid
779765
780- # Create response header using the standardized method
781- header = create_header (
782- sender = str (self .aggregator .uuid ),
783- receiver = collaborator_id ,
784- federation_uuid = str (self .aggregator .federation_uuid ),
785- single_col_cert_common_name = self .aggregator .single_col_cert_common_name or "" ,
786- )
766+ # Use the consolidated validation method
767+ self ._is_authorized (collaborator_id , federation_id )
787768
788- # Create response with empty tensor if not found
789- response_proto = aggregator_pb2 .GetAggregatedTensorResponse (
790- header = header ,
791- round_number = round_number ,
792- tensor = named_tensor
793- if named_tensor is not None
794- else aggregator_pb2 .NamedTensorProto (),
795- )
769+ # Validate request header similar to gRPC implementation
770+ assert tensors_request .header .receiver == str (self .aggregator .uuid ), (
771+ f"Header receiver mismatch. Expected: { self .aggregator .uuid } , "
772+ f"Got: { tensors_request .header .receiver } "
773+ )
796774
797- logger .debug (f"Tensor retrieval completed in { time .time () - start_time :.2f} seconds" )
798- return jsonify (json_format .MessageToDict (response_proto ))
775+ assert tensors_request .header .federation_uuid == str (
776+ self .aggregator .federation_uuid
777+ ), (
778+ f"Federation UUID mismatch. Expected: { self .aggregator .federation_uuid } , "
779+ f"Got: { tensors_request .header .federation_uuid } "
780+ )
781+
782+ expected_cn = self .aggregator .single_col_cert_common_name or ""
783+ assert tensors_request .header .single_col_cert_common_name == expected_cn , (
784+ f"Single col cert CN mismatch. Expected: { expected_cn } , "
785+ f"Got: { tensors_request .header .single_col_cert_common_name } "
786+ )
787+
788+ # Get tensors from aggregator - similar to gRPC implementation
789+ logger .debug (
790+ f"Processing batch request for { len (tensors_request .tensor_specs )} tensors"
791+ )
792+
793+ named_tensors = []
794+ for ts in tensors_request .tensor_specs :
795+ named_tensor = self .aggregator .get_aggregated_tensor (
796+ ts .tensor_name ,
797+ ts .round_number ,
798+ ts .report ,
799+ tuple (ts .tags ),
800+ ts .require_lossless ,
801+ collaborator_id ,
802+ )
803+ # Add tensor to list (None tensors will be handled by the client)
804+ if named_tensor is not None :
805+ named_tensors .append (named_tensor )
806+ else :
807+ # Add empty tensor placeholder to maintain order
808+ named_tensors .append (aggregator_pb2 .NamedTensorProto ())
809+
810+ # Create response header using the standardized method
811+ header = create_header (
812+ sender = str (self .aggregator .uuid ),
813+ receiver = collaborator_id ,
814+ federation_uuid = str (self .aggregator .federation_uuid ),
815+ single_col_cert_common_name = self .aggregator .single_col_cert_common_name or "" ,
816+ )
817+
818+ # Create response
819+ response_proto = aggregator_pb2 .GetAggregatedTensorsResponse (
820+ header = header , tensors = named_tensors
821+ )
822+
823+ logger .debug (
824+ f"Batch tensor retrieval completed in { time .time () - start_time :.2f} seconds. "
825+ f"Returned { len (named_tensors )} tensors"
826+ )
827+ return jsonify (json_format .MessageToDict (response_proto ))
828+
829+ except AssertionError as e :
830+ logger .error (f"Header validation failed: { str (e )} " )
831+ abort (400 , str (e ))
832+ except Exception as e :
833+ logger .error (f"Error processing batch tensor request: { str (e )} " )
834+ abort (400 , f"Error processing batch tensor request: { str (e )} " )
799835
800836 def _setup_relay_route (self ):
801837 """Set up the /interop/relay endpoint."""
0 commit comments