@@ -835,8 +835,8 @@ def load_project_version(self, version_id: str) -> Project:
         >>> version.wait_for_completion()
         >>> version.print_goal_report()

-        With the :obj:`project_versions.ProjectVersion` object loaded, you are able to check progress and
-        goal statuses.
+        With the :obj:`project_versions.ProjectVersion` object loaded, you are able to
+        check progress and goal statuses.
         """
         endpoint = f"versions/{version_id}"
         version_data = self.api.get_request(endpoint)
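
As a quick orientation, a minimal usage sketch of `load_project_version` — the client setup and ids are placeholders, while `wait_for_completion()` and `print_goal_report()` come straight from the docstring above:

```python
import openlayer

# Hypothetical setup: the API key and version id are placeholders.
client = openlayer.OpenlayerClient("YOUR_API_KEY")

# GETs versions/{version_id} and wraps the response, as in the method body above.
version = client.load_project_version(version_id="YOUR_VERSION_ID")

# With the ProjectVersion object loaded, check progress and goal statuses.
version.wait_for_completion()
version.print_goal_report()
```
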
@@ -896,11 +896,17 @@ def create_inference_pipeline(
896896 " creating it." ,
897897 ) from None
898898
899- # Validate reference dataset and augment config
899+ # Load dataset config
900900 if reference_dataset_config_file_path is not None :
901+ reference_dataset_config = utils .read_yaml (
902+ reference_dataset_config_file_path
903+ )
904+
905+ if reference_dataset_config is not None :
906+ # Validate reference dataset and augment config
901907 dataset_validator = dataset_validators .get_validator (
902908 task_type = task_type ,
903- dataset_config_file_path = reference_dataset_config_file_path ,
909+ dataset_config = reference_dataset_config ,
904910 dataset_df = reference_df ,
905911 )
906912 failed_validations = dataset_validator .validate ()
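
The substantive change in this hunk: the YAML config is parsed once up front, and the validator now receives the parsed dict (`dataset_config`) rather than a file path. A hedged sketch of the new call shape — the file name is a placeholder, the other names come from this diff:

```python
# Sketch of the reworked flow, assuming the names used in this diff.
reference_dataset_config = utils.read_yaml("dataset_config.yaml")  # placeholder path

if reference_dataset_config is not None:
    dataset_validator = dataset_validators.get_validator(
        task_type=task_type,
        dataset_config=reference_dataset_config,  # parsed dict, no longer a file path
        dataset_df=reference_df,
    )
    failed_validations = dataset_validator.validate()
```
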
@@ -912,40 +918,39 @@ def create_inference_pipeline(
912918 " upload." ,
913919 ) from None
914920
915- # Load dataset config and augment with defaults
916- reference_dataset_config = utils .read_yaml (
917- reference_dataset_config_file_path
918- )
919921 reference_dataset_data = DatasetSchema ().load (
920922 {"task_type" : task_type .value , ** reference_dataset_config }
921923 )
922924
923- with tempfile .TemporaryDirectory () as tmp_dir :
924925 # Copy relevant files to tmp dir if reference dataset is provided
925- if reference_dataset_config_file_path is not None :
926+ with tempfile . TemporaryDirectory () as tmp_dir :
926927 utils .write_yaml (
927928 reference_dataset_data , f"{ tmp_dir } /dataset_config.yaml"
928929 )
929930 if reference_df is not None :
930931 reference_df .to_csv (f"{ tmp_dir } /dataset.csv" , index = False )
931932 else :
932933 shutil .copy (
933- reference_dataset_file_path ,
934- f"{ tmp_dir } /dataset.csv" ,
934+ reference_dataset_file_path , f"{ tmp_dir } /dataset.csv"
935935 )
936936
937- tar_file_path = os .path .join (tmp_dir , "tarfile" )
938- with tarfile .open (tar_file_path , mode = "w:gz" ) as tar :
939- tar .add (tmp_dir , arcname = os .path .basename ("reference_dataset" ))
940-
937+ tar_file_path = os .path .join (tmp_dir , "tarfile" )
938+ with tarfile .open (tar_file_path , mode = "w:gz" ) as tar :
939+ tar .add (tmp_dir , arcname = os .path .basename ("reference_dataset" ))
940+
941+ endpoint = f"projects/{ project_id } /inference-pipelines"
942+ inference_pipeline_data = self .api .upload (
943+ endpoint = endpoint ,
944+ file_path = tar_file_path ,
945+ object_name = "tarfile" ,
946+ body = inference_pipeline_config ,
947+ storage_uri_key = "referenceDatasetUri" ,
948+ method = "POST" ,
949+ )
950+ else :
941951 endpoint = f"projects/{ project_id } /inference-pipelines"
942- inference_pipeline_data = self .api .upload (
943- endpoint = endpoint ,
944- file_path = tar_file_path ,
945- object_name = "tarfile" ,
946- body = inference_pipeline_config ,
947- storage_uri_key = "referenceDatasetUri" ,
948- method = "POST" ,
952+ inference_pipeline_data = self .api .post_request (
953+ endpoint = endpoint , body = inference_pipeline_config
949954 )
950955 inference_pipeline = InferencePipeline (
951956 inference_pipeline_data , self .api .upload , self , task_type
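
Net effect of the restructuring: the temp dir, tarball, and `api.upload` (with `storage_uri_key="referenceDatasetUri"`) now run only when a reference dataset config is present; otherwise the pipeline is created with a plain `post_request`. A hypothetical caller's view of the two paths, assuming the public signature mirrors the parameters used internally here:

```python
# Path 1 (assumed signature): no reference dataset -> plain POST, no upload.
pipeline = client.create_inference_pipeline(
    project_id=project_id,
    task_type=task_type,
)

# Path 2 (assumed signature): config + CSV are written to a temp dir,
# tarred, and uploaded under "referenceDatasetUri".
pipeline = client.create_inference_pipeline(
    project_id=project_id,
    task_type=task_type,
    reference_dataset_config_file_path="dataset_config.yaml",  # placeholder
    reference_df=reference_df,
)
```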