diff --git a/hugot_test.go b/hugot_test.go
index 13c2acc..3b5f68b 100644
--- a/hugot_test.go
+++ b/hugot_test.go
@@ -155,14 +155,14 @@ func TestFeatureExtractionPipeline(t *testing.T) {
 	var expectedResults map[string][][]float32
 	Check(json.Unmarshal(resultsByte, &expectedResults))
 	var testResults [][]float32
-	var result [][]float32
+	var result pipelines.FeatureExtractionOutput
 
 	// test 'robert smith'
 	testResults = expectedResults["test1output"]
 	for i := 1; i <= 10; i++ {
 		result, err = pipeline.Run([]string{"robert smith"})
 		Check(err)
-		e := floatsEqual(result[0], testResults[0])
+		e := floatsEqual(result.Embeddings[0], testResults[0])
 		if e != nil {
 			t.Logf("Test 1: The neural network didn't produce the correct result on loop %d: %s\n", i, e)
 			t.FailNow()
@@ -174,7 +174,7 @@ func TestFeatureExtractionPipeline(t *testing.T) {
 	for i := 1; i <= 10; i++ {
 		result, err = pipeline.Run([]string{"robert smith junior", "francis ford coppola"})
 		Check(err)
-		for j, res := range result {
+		for j, res := range result.Embeddings {
 			e := floatsEqual(res, testResults[j])
 			if e != nil {
 				t.Logf("Test 2: The neural network didn't produce the correct result on loop %d: %s\n", i, e)
@@ -191,12 +191,12 @@ func TestFeatureExtractionPipeline(t *testing.T) {
 
 	for k, sentencePair := range testPairs {
 		// these vectors should be the same
-		firstRes, err := pipeline.Run(sentencePair[0])
-		Check(err)
-		firstEmbedding := firstRes[0]
-		secondRes, err := pipeline.Run(sentencePair[1])
-		Check(err)
-		secondEmbedding := secondRes[0]
+		firstRes, err2 := pipeline.Run(sentencePair[0])
+		Check(err2)
+		firstEmbedding := firstRes.Embeddings[0]
+		secondRes, err3 := pipeline.Run(sentencePair[1])
+		Check(err3)
+		secondEmbedding := secondRes.Embeddings[0]
 		e := floatsEqual(firstEmbedding, secondEmbedding)
 		if e != nil {
 			t.Logf("Equality failed for determinism test %s test with pairs %s and %s", k, strings.Join(sentencePair[0], ","), strings.Join(sentencePair[1], ","))
@@ -205,6 +205,13 @@ func TestFeatureExtractionPipeline(t *testing.T) {
 			t.Fail()
 		}
 	}
+
+	// The timing counters are uint64 (they are updated with atomic.AddUint64),
+	// and assert.Greater fails unless both operands have the same type.
+	assert.Greater(t, pipeline.PipelineTimings.NumCalls, uint64(0), "PipelineTimings.NumCalls should be greater than 0")
+	assert.Greater(t, pipeline.PipelineTimings.TotalNS, uint64(0), "PipelineTimings.TotalNS should be greater than 0")
+	assert.Greater(t, pipeline.TokenizerTimings.NumCalls, uint64(0), "TokenizerTimings.NumCalls should be greater than 0")
+	assert.Greater(t, pipeline.TokenizerTimings.TotalNS, uint64(0), "TokenizerTimings.TotalNS should be greater than 0")
 }
 
 // utilities
diff --git a/pipelines/featureExtraction.go b/pipelines/featureExtraction.go
index e562b46..89eb7d8 100644
--- a/pipelines/featureExtraction.go
+++ b/pipelines/featureExtraction.go
@@ -15,7 +15,10 @@ type FeatureExtractionPipeline struct {
 	BasePipeline
 }
 
-type FeatureExtractionOutput [][]float32
+// FeatureExtractionOutput is the result returned by the feature extraction pipeline.
+type FeatureExtractionOutput struct {
+	Embeddings [][]float32
+}
 
 // NewFeatureExtractionPipeline Initialize a feature extraction pipeline
 func NewFeatureExtractionPipeline(modelPath string, name string) (*FeatureExtractionPipeline, error) {
@@ -75,7 +78,7 @@ func (p *FeatureExtractionPipeline) Postprocess(batch PipelineBatch) (FeatureExt
 			vectorCounter++
 		}
 	}
-	return outputs, nil
+	return FeatureExtractionOutput{Embeddings: outputs}, nil
 }
 
 func meanPooling(tokens [][]float32, input TokenizedInput, maxSequence int, dimensions int) []float32 {
@@ -103,7 +106,7 @@ func (p *FeatureExtractionPipeline) Run(inputs []string) (FeatureExtractionOutpu
 	batch := p.Preprocess(inputs)
 	batch, forwardError := p.Forward(batch)
 	if forwardError != nil {
-		return nil, forwardError
+		return FeatureExtractionOutput{}, forwardError
 	}
 	return p.Postprocess(batch)
 }
diff --git a/pipelines/pipeline.go b/pipelines/pipeline.go
index 77ae9cc..e2d2b5c 100644
--- a/pipelines/pipeline.go
+++ b/pipelines/pipeline.go
@@ -196,8 +196,8 @@ func (p *BasePipeline) Preprocess(inputs []string) PipelineBatch {
 		}
 	}
 
-	atomic.AddUint64(&p.PipelineTimings.NumCalls, 1)
-	atomic.AddUint64(&p.PipelineTimings.TotalNS, uint64(time.Since(start)))
+	atomic.AddUint64(&p.TokenizerTimings.NumCalls, 1)
+	atomic.AddUint64(&p.TokenizerTimings.TotalNS, uint64(time.Since(start)))
 	batch := p.convertInputToTensors(outputs, maxSequence+1)
 	return batch
 }