Skip to content

Commit

Permalink
Improve FeatureExtractionOutput, test for empty timings
Browse files (browse the repository at this point in the history)
  • Loading branch information
RJKeevil committed Feb 9, 2024
1 parent be21388 commit b3e0c29
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 14 deletions.
23 changes: 14 additions & 9 deletions hugot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,14 @@ func TestFeatureExtractionPipeline(t *testing.T) {
var expectedResults map[string][][]float32
Check(json.Unmarshal(resultsByte, &expectedResults))
var testResults [][]float32
var result [][]float32
var result pipelines.FeatureExtractionOutput
// test 'robert smith'

testResults = expectedResults["test1output"]
for i := 1; i <= 10; i++ {
result, err = pipeline.Run([]string{"robert smith"})
Check(err)
e := floatsEqual(result[0], testResults[0])
e := floatsEqual(result.Embeddings[0], testResults[0])
if e != nil {
t.Logf("Test 1: The neural network didn't produce the correct result on loop %d: %s\n", i, e)
t.FailNow()
Expand All @@ -174,7 +174,7 @@ func TestFeatureExtractionPipeline(t *testing.T) {
for i := 1; i <= 10; i++ {
result, err = pipeline.Run([]string{"robert smith junior", "francis ford coppola"})
Check(err)
for j, res := range result {
for j, res := range result.Embeddings {
e := floatsEqual(res, testResults[j])
if e != nil {
t.Logf("Test 2: The neural network didn't produce the correct result on loop %d: %s\n", i, e)
Expand All @@ -191,12 +191,12 @@ func TestFeatureExtractionPipeline(t *testing.T) {

for k, sentencePair := range testPairs {
// these vectors should be the same
firstRes, err := pipeline.Run(sentencePair[0])
Check(err)
firstEmbedding := firstRes[0]
secondRes, err := pipeline.Run(sentencePair[1])
Check(err)
secondEmbedding := secondRes[0]
firstRes, err2 := pipeline.Run(sentencePair[0])
Check(err2)
firstEmbedding := firstRes.Embeddings[0]
secondRes, err3 := pipeline.Run(sentencePair[1])
Check(err3)
secondEmbedding := secondRes.Embeddings[0]
e := floatsEqual(firstEmbedding, secondEmbedding)
if e != nil {
t.Logf("Equality failed for determinism test %s test with pairs %s and %s", k, strings.Join(sentencePair[0], ","), strings.Join(sentencePair[1], ","))
Expand All @@ -205,6 +205,11 @@ func TestFeatureExtractionPipeline(t *testing.T) {
t.Fail()
}
}

// The timing counters are incremented with atomic.AddUint64, so they are
// uint64 values. assert.Greater refuses to compare operands of different
// kinds, and an untyped 0 is passed as an int, so the lower bound must be
// spelled uint64(0) for these checks to compare values at all.
assert.Greater(t, pipeline.PipelineTimings.NumCalls, uint64(0), "PipelineTimings.NumCalls should be greater than 0")
assert.Greater(t, pipeline.PipelineTimings.TotalNS, uint64(0), "PipelineTimings.TotalNS should be greater than 0")
assert.Greater(t, pipeline.TokenizerTimings.NumCalls, uint64(0), "TokenizerTimings.NumCalls should be greater than 0")
assert.Greater(t, pipeline.TokenizerTimings.TotalNS, uint64(0), "TokenizerTimings.TotalNS should be greater than 0")
}

// utilities
Expand Down
8 changes: 5 additions & 3 deletions pipelines/featureExtraction.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ type FeatureExtractionPipeline struct {
BasePipeline
}

type FeatureExtractionOutput [][]float32
// FeatureExtractionOutput is the result type returned by the feature
// extraction pipeline's Run and Postprocess methods.
type FeatureExtractionOutput struct {
// Embeddings holds the float32 vectors built by Postprocess. The tests
// index it in the same order as the input strings, so presumably there is
// one vector per input (confirm against Postprocess).
Embeddings [][]float32
}

// NewFeatureExtractionPipeline Initialize a feature extraction pipeline
func NewFeatureExtractionPipeline(modelPath string, name string) (*FeatureExtractionPipeline, error) {
Expand Down Expand Up @@ -75,7 +77,7 @@ func (p *FeatureExtractionPipeline) Postprocess(batch PipelineBatch) (FeatureExt
vectorCounter++
}
}
return outputs, nil
return FeatureExtractionOutput{Embeddings: outputs}, nil
}

func meanPooling(tokens [][]float32, input TokenizedInput, maxSequence int, dimensions int) []float32 {
Expand Down Expand Up @@ -103,7 +105,7 @@ func (p *FeatureExtractionPipeline) Run(inputs []string) (FeatureExtractionOutpu
batch := p.Preprocess(inputs)
batch, forwardError := p.Forward(batch)
if forwardError != nil {
return nil, forwardError
return FeatureExtractionOutput{}, forwardError
}
return p.Postprocess(batch)
}
4 changes: 2 additions & 2 deletions pipelines/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ func (p *BasePipeline) Preprocess(inputs []string) PipelineBatch {
}
}

atomic.AddUint64(&p.PipelineTimings.NumCalls, 1)
atomic.AddUint64(&p.PipelineTimings.TotalNS, uint64(time.Since(start)))
atomic.AddUint64(&p.TokenizerTimings.NumCalls, 1)
atomic.AddUint64(&p.TokenizerTimings.TotalNS, uint64(time.Since(start)))
batch := p.convertInputToTensors(outputs, maxSequence+1)
return batch
}
Expand Down

0 comments on commit b3e0c29

Please sign in to comment.