Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
<?php

declare(strict_types=1);

namespace Doctrine\ODM\MongoDB\Aggregation\Stage;

use Doctrine\ODM\MongoDB\Aggregation\Builder;
use Doctrine\ODM\MongoDB\Aggregation\Stage;
use Doctrine\ODM\MongoDB\Query\Expr;
use MongoDB\BSON\Binary;

Check failure on line 10 in lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php

View workflow job for this annotation

GitHub Actions / Coding Standards / Coding Standards (PHP: 8.4)

Type MongoDB\BSON\Binary is not used in this file.
use MongoDB\BSON\Decimal128;
use MongoDB\BSON\Int64;

/**
* @phpstan-type Vector list<int|Int64>|list<float|Decimal128>|list<bool|0|1>
* @phpstan-type VectorSearchStageExpression array{
* '$vectorSearch': object{
* exact?: bool,
* filter?: object,
* index?: string,
* limit?: int,
* numCandidates?: int,
* path?: string,
* queryVector?: Vector,
* }
* }
*/
class VectorSearch extends Stage
{
private ?bool $exact = null;
private ?Expr $filter = null;
private ?string $index = null;
private ?int $limit = null;
private ?int $numCandidates = null;
private ?string $path = null;
/** @phpstan-var Vector|null */
private ?array $queryVector = null;

/**
 * Creates the stage for the given aggregation builder.
 *
 * @param Builder $builder Passed on unchanged to the parent Stage constructor.
 */
public function __construct(Builder $builder)
{
    parent::__construct($builder);
}

/**
 * Builds the pipeline expression for this stage.
 *
 * Only the options that have been set (i.e. are not null) are added, in the
 * fixed order exact, filter, index, limit, numCandidates, path, queryVector.
 * A stage on which nothing was configured therefore yields an empty option
 * list. No check is made here that any particular option is present.
 *
 * @return array<string, array<string, mixed>> Single-entry array keyed by the stage name.
 */
public function getExpression(): array
{
// Collects the options to emit; stays empty when no setter was called.
$params = [];

if ($this->exact !== null) {
$params['exact'] = $this->exact;
}

// The filter is held as an Expr; the result of its getQuery() call is what gets emitted.
if ($this->filter !== null) {
$params['filter'] = $this->filter->getQuery();
}

if ($this->index !== null) {
$params['index'] = $this->index;
}

if ($this->limit !== null) {
$params['limit'] = $this->limit;
}

if ($this->numCandidates !== null) {
$params['numCandidates'] = $this->numCandidates;
}

// The path is emitted exactly as it was passed to path(); no translation is applied here.
if ($this->path !== null) {
$params['path'] = $this->path;
}
Comment on lines +68 to +70
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The path should be mapped to the field name using the class metadata.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a TODO item? If so, is there another ticket to track this?

This seems related to #2820 (comment) from the PR that introduced a VectorSearchIndex mapping.


if ($this->queryVector !== null) {
$params['queryVector'] = $this->queryVector;
}
Comment on lines +48 to +74
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index, limit, path and queryVector are required. The server will return an error if they are not used.
Looking at the other stages, it seems that we don't check this before.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me. Maybe it's nice to document this as a standard in some sort of development decision documentation in the repo, as I think we've had conversations about this before.


return [$this->getStageName() => $params];
}

/**
 * Sets the value emitted as the "exact" option of the stage.
 *
 * @param bool $exact Stored as given; emitted even when false, since only null counts as "not set".
 *
 * @return static This stage, to allow method chaining.
 */
public function exact(bool $exact): static
{
    $this->exact = $exact;

    return $this;
}

/**
 * Sets the expression emitted as the "filter" option of the stage.
 *
 * The Expr object itself is stored; its query is only read when the stage
 * expression is built.
 *
 * @param Expr $filter Query expression to attach to the stage.
 *
 * @return static This stage, to allow method chaining.
 */
public function filter(Expr $filter): static
{
    $this->filter = $filter;

    return $this;
}

/**
 * Sets the value emitted as the "index" option of the stage.
 *
 * NOTE(review): reportedly required by the server; this class does not enforce it.
 *
 * @param string $index Index name, stored and emitted as given.
 *
 * @return static This stage, to allow method chaining.
 */
public function index(string $index): static
{
    $this->index = $index;

    return $this;
}

/**
 * Sets the value emitted as the "limit" option of the stage.
 *
 * NOTE(review): reportedly required by the server; this class does not enforce it.
 *
 * @param int $limit Stored and emitted as given; no range check is applied here.
 *
 * @return static This stage, to allow method chaining.
 */
public function limit(int $limit): static
{
    $this->limit = $limit;

    return $this;
}

/**
 * Sets the value emitted as the "numCandidates" option of the stage.
 *
 * @param int $numCandidates Stored and emitted as given; no range check is applied here.
 *
 * @return static This stage, to allow method chaining.
 */
public function numCandidates(int $numCandidates): static
{
    $this->numCandidates = $numCandidates;

    return $this;
}

/**
 * Sets the value emitted as the "path" option of the stage.
 *
 * The string is stored verbatim and later emitted unchanged.
 *
 * NOTE(review): reportedly required by the server; this class does not enforce it.
 *
 * @param string $path Path of the field to search, exactly as it should appear in the stage.
 *
 * @return static This stage, to allow method chaining.
 */
public function path(string $path): static
{
    $this->path = $path;

    return $this;
}

/**
 * Sets the value emitted as the "queryVector" option of the stage.
 *
 * The array is stored and emitted as given; its elements are not inspected here.
 *
 * NOTE(review): reportedly required by the server; this class does not enforce it.
 *
 * @phpstan-param Vector $queryVector
 *
 * @return static This stage, to allow method chaining.
 */
public function queryVector(array $queryVector): static
{
    $this->queryVector = $queryVector;

    return $this;
}

/**
 * Returns the pipeline operator name; getExpression() uses it as the key
 * of the array it returns.
 */
protected function getStageName(): string
{
    return '$vectorSearch';
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
<?php

declare(strict_types=1);

namespace Doctrine\ODM\MongoDB\Tests\Aggregation\Stage;

use Doctrine\ODM\MongoDB\Aggregation\Stage\VectorSearch;
use Doctrine\ODM\MongoDB\Tests\Aggregation\AggregationTestTrait;
use Doctrine\ODM\MongoDB\Tests\BaseTestCase;

/**
 * Tests for the VectorSearch aggregation stage.
 *
 * Each test configures a subset of the stage options and compares the result
 * of getExpression() against the exact expected array.
 */
class VectorSearchTest extends BaseTestCase
{
    use AggregationTestTrait;

    /** A stage with no option set yields an empty option list. */
    public function testEmptyStage(): void
    {
        $expression = $this->newVectorSearchStage()->getExpression();

        self::assertSame(['$vectorSearch' => []], $expression);
    }

    /** Only the "exact" option is emitted once exact() has been called. */
    public function testExact(): void
    {
        $stage = $this->newVectorSearchStage();
        $stage->exact(true);

        $expected = ['$vectorSearch' => ['exact' => true]];

        self::assertSame($expected, $stage->getExpression());
    }

    /** The filter expression is emitted as its query array. */
    public function testFilter(): void
    {
        // The same builder creates both the stage and the match expression.
        $builder = $this->getTestAggregationBuilder();
        $stage   = new VectorSearch($builder);
        $filter  = $builder->matchExpr()->field('status')->notEqual('inactive');
        $stage->filter($filter);

        $expected = ['$vectorSearch' => ['filter' => ['status' => ['$ne' => 'inactive']]]];

        self::assertSame($expected, $stage->getExpression());
    }

    /** Only the "index" option is emitted once index() has been called. */
    public function testIndex(): void
    {
        $stage = $this->newVectorSearchStage();
        $stage->index('myIndex');

        $expected = ['$vectorSearch' => ['index' => 'myIndex']];

        self::assertSame($expected, $stage->getExpression());
    }

    /** Only the "limit" option is emitted once limit() has been called. */
    public function testLimit(): void
    {
        $stage = $this->newVectorSearchStage();
        $stage->limit(10);

        $expected = ['$vectorSearch' => ['limit' => 10]];

        self::assertSame($expected, $stage->getExpression());
    }

    /** Only the "numCandidates" option is emitted once numCandidates() has been called. */
    public function testNumCandidates(): void
    {
        $stage = $this->newVectorSearchStage();
        $stage->numCandidates(5);

        $expected = ['$vectorSearch' => ['numCandidates' => 5]];

        self::assertSame($expected, $stage->getExpression());
    }

    /** Only the "path" option is emitted once path() has been called. */
    public function testPath(): void
    {
        $stage = $this->newVectorSearchStage();
        $stage->path('vectorField');

        $expected = ['$vectorSearch' => ['path' => 'vectorField']];

        self::assertSame($expected, $stage->getExpression());
    }

    /** Only the "queryVector" option is emitted once queryVector() has been called. */
    public function testQueryVector(): void
    {
        $stage = $this->newVectorSearchStage();
        $stage->queryVector([1, 2, 3]);

        $expected = ['$vectorSearch' => ['queryVector' => [1, 2, 3]]];

        self::assertSame($expected, $stage->getExpression());
    }

    /** All setters can be chained, and every option appears in the expression. */
    public function testChainingAllOptions(): void
    {
        $expected = [
            '$vectorSearch' => [
                'exact' => false,
                'filter' => ['status' => ['$ne' => 'inactive']],
                'index' => 'idx',
                'limit' => 7,
                'numCandidates' => 3,
                'path' => 'vec',
                'queryVector' => [0.1, 0.2],
            ],
        ];

        $builder = $this->getTestAggregationBuilder();
        $stage   = (new VectorSearch($builder))
            ->exact(false)
            ->filter($builder->matchExpr()->field('status')->notEqual('inactive'))
            ->index('idx')
            ->limit(7)
            ->numCandidates(3)
            ->path('vec')
            ->queryVector([0.1, 0.2]);

        self::assertSame($expected, $stage->getExpression());
    }

    /** Creates a stage bound to a fresh test aggregation builder. */
    private function newVectorSearchStage(): VectorSearch
    {
        return new VectorSearch($this->getTestAggregationBuilder());
    }
}
Loading