@@ -69,9 +69,9 @@ def compress(self, tensor_key, data, require_lossless=False, **kwargs):
6969 metadata: metadata associated with compressed tensor.
7070 """
7171 if require_lossless :
72- compressed_nparray , metadata = self .lossless_pipeline .forward (data , ** kwargs )
72+ data , metadata = self .lossless_pipeline .forward (data , ** kwargs )
7373 else :
74- compressed_nparray , metadata = self .compression_pipeline .forward (data , ** kwargs )
74+ data , metadata = self .compression_pipeline .forward (data , ** kwargs )
7575 # Define the compressed tensorkey that should be
7676 # returned ('trained.delta'->'trained.delta.lossy_compressed')
7777 tensor_name , origin , round_number , report , tags = tensor_key
@@ -80,7 +80,7 @@ def compress(self, tensor_key, data, require_lossless=False, **kwargs):
8080 else :
8181 new_tags = change_tags (tags , add_field = "lossy_compressed" )
8282 compressed_tensor_key = TensorKey (tensor_name , origin , round_number , report , new_tags )
83- return compressed_tensor_key , compressed_nparray , metadata
83+ return compressed_tensor_key , data , metadata
8484
8585 def decompress (
8686 self ,
@@ -121,13 +121,9 @@ def decompress(
121121 assert "compressed" in tags , "Cannot losslessly decompress lossy tensor"
122122
123123 if require_lossless or "compressed" in tags :
124- decompressed_nparray = self .lossless_pipeline .backward (
125- data , transformer_metadata , ** kwargs
126- )
124+ data = self .lossless_pipeline .backward (data , transformer_metadata , ** kwargs )
127125 else :
128- decompressed_nparray = self .compression_pipeline .backward (
129- data , transformer_metadata , ** kwargs
130- )
126+ data = self .compression_pipeline .backward (data , transformer_metadata , ** kwargs )
131127 # Define the decompressed tensorkey that should be returned
132128 if "lossy_compressed" in tags :
133129 new_tags = change_tags (
@@ -144,7 +140,7 @@ def decompress(
144140 else :
145141 raise NotImplementedError ("Decompression is only supported on compressed data" )
146142
147- return decompressed_tensor_key , decompressed_nparray
143+ return decompressed_tensor_key , data
148144
149145 @staticmethod
150146 def generate_delta (tensor_key , nparray , base_model_nparray ):
0 commit comments