Mirror of https://github.com/immich-app/immich.git (synced 2025-06-14 21:38:26 +02:00)
feat: vectorchord (#18042)
* wip: auto-detect available extensions, auto-recovery, fix reindexing check, use original image for ml
* set probes
* update image for sql checker, update images for gha
* cascade
* fix new instance
* accurate dummy vector
* simplify dummy
* preexisting pg docs
* handle different db name
* maybe fix sql generation
* revert refreshFaces sql change
* redundant switch
* outdated message
* update docker compose files
* Update docs/docs/administration/postgres-standalone.md

  Co-authored-by: Daniel Dietzler <36593685+danieldietzler@users.noreply.github.com>

* tighten range
* avoid always printing "vector reindexing complete"
* remove nesting
* use new images
* add vchord to unit tests
* debug e2e image
* mention 1.107.2 in startup error
* support new vchord versions

---------

Co-authored-by: Daniel Dietzler <36593685+danieldietzler@users.noreply.github.com>
parent fe71894308
commit 0d773af6c3

35 changed files with 572 additions and 444 deletions
server/src
  constants.ts
  decorators.ts
  dtos/
  enum.ts
  migrations/
    1700713871511-UsePgVectors.ts
    1700713994428-AddCLIPEmbeddingIndex.ts
    1700714033632-AddFaceEmbeddingIndex.ts
    1718486162779-AddFaceSearchRelation.ts
  queries/
  repositories/
    config.repository.spec.ts
    config.repository.ts
    database.repository.ts
    person.repository.ts
    search.repository.ts
  schema/migrations/
  services/
    database.service.spec.ts
    database.service.ts
    person.service.ts
    smart-info.service.spec.ts
    smart-info.service.ts
  types.ts
  utils/
server/test
@@ -1,9 +1,10 @@
 import { Duration } from 'luxon';
 import { readFileSync } from 'node:fs';
 import { SemVer } from 'semver';
-import { DatabaseExtension, ExifOrientation } from 'src/enum';
+import { DatabaseExtension, ExifOrientation, VectorIndex } from 'src/enum';

 export const POSTGRES_VERSION_RANGE = '>=14.0.0';
+export const VECTORCHORD_VERSION_RANGE = '>=0.3 <1';
 export const VECTORS_VERSION_RANGE = '>=0.2 <0.4';
 export const VECTOR_VERSION_RANGE = '>=0.5 <1';

@@ -20,8 +21,22 @@ export const EXTENSION_NAMES: Record<DatabaseExtension, string> = {
   earthdistance: 'earthdistance',
   vector: 'pgvector',
   vectors: 'pgvecto.rs',
+  vchord: 'VectorChord',
 } as const;

+export const VECTOR_EXTENSIONS = [
+  DatabaseExtension.VECTORCHORD,
+  DatabaseExtension.VECTORS,
+  DatabaseExtension.VECTOR,
+] as const;
+
+export const VECTOR_INDEX_TABLES = {
+  [VectorIndex.CLIP]: 'smart_search',
+  [VectorIndex.FACE]: 'face_search',
+} as const;
+
+export const VECTORCHORD_LIST_SLACK_FACTOR = 1.2;
+
 export const SALT_ROUNDS = 10;

 export const IWorker = 'IWorker';
@@ -116,7 +116,7 @@ export const DummyValue = {
   DATE: new Date(),
   TIME_BUCKET: '2024-01-01T00:00:00.000Z',
   BOOLEAN: true,
-  VECTOR: '[1, 2, 3]',
+  VECTOR: JSON.stringify(Array.from({ length: 512 }, () => 0)),
 };

 export const GENERATE_SQL_KEY = 'generate-sql-key';
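Note on the DummyValue.VECTOR change above: this dummy feeds the SQL snapshot generator, and the previous 3-element literal did not match the 512-dimension embedding columns the generated queries actually target. A minimal sketch of what the new value produces (illustrative only; just the shape matters, not the contents):

const dummy = JSON.stringify(Array.from({ length: 512 }, () => 0));
console.log(JSON.parse(dummy).length); // 512, matching vector(512) columns
console.log(dummy.slice(0, 10)); // '[0,0,0,0,0'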
@@ -154,9 +154,9 @@ export class EnvDto {
   @Optional()
   DB_USERNAME?: string;

-  @IsEnum(['pgvector', 'pgvecto.rs'])
+  @IsEnum(['pgvector', 'pgvecto.rs', 'vectorchord'])
   @Optional()
-  DB_VECTOR_EXTENSION?: 'pgvector' | 'pgvecto.rs';
+  DB_VECTOR_EXTENSION?: 'pgvector' | 'pgvecto.rs' | 'vectorchord';

   @IsString()
   @Optional()
@@ -414,6 +414,7 @@ export enum DatabaseExtension {
   EARTH_DISTANCE = 'earthdistance',
   VECTOR = 'vector',
   VECTORS = 'vectors',
+  VECTORCHORD = 'vchord',
 }

 export enum BootstrapEventPriority {
@@ -1,15 +1,13 @@
-import { ConfigRepository } from 'src/repositories/config.repository';
+import { getVectorExtension } from 'src/repositories/database.repository';
 import { getCLIPModelInfo } from 'src/utils/misc';
 import { MigrationInterface, QueryRunner } from 'typeorm';

-const vectorExtension = new ConfigRepository().getEnv().database.vectorExtension;
-
 export class UsePgVectors1700713871511 implements MigrationInterface {
   name = 'UsePgVectors1700713871511';

   public async up(queryRunner: QueryRunner): Promise<void> {
     await queryRunner.query(`SET search_path TO "$user", public, vectors`);
-    await queryRunner.query(`CREATE EXTENSION IF NOT EXISTS ${vectorExtension}`);
+    await queryRunner.query(`CREATE EXTENSION IF NOT EXISTS ${await getVectorExtension(queryRunner)}`);
     const faceDimQuery = await queryRunner.query(`
       SELECT CARDINALITY(embedding::real[]) as dimsize
       FROM asset_faces
@@ -1,13 +1,12 @@
-import { ConfigRepository } from 'src/repositories/config.repository';
+import { getVectorExtension } from 'src/repositories/database.repository';
 import { vectorIndexQuery } from 'src/utils/database';
 import { MigrationInterface, QueryRunner } from 'typeorm';

-const vectorExtension = new ConfigRepository().getEnv().database.vectorExtension;
-
 export class AddCLIPEmbeddingIndex1700713994428 implements MigrationInterface {
   name = 'AddCLIPEmbeddingIndex1700713994428';

   public async up(queryRunner: QueryRunner): Promise<void> {
+    const vectorExtension = await getVectorExtension(queryRunner);
     await queryRunner.query(`SET search_path TO "$user", public, vectors`);

     await queryRunner.query(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: 'clip_index' }));
@@ -1,13 +1,12 @@
-import { ConfigRepository } from 'src/repositories/config.repository';
+import { getVectorExtension } from 'src/repositories/database.repository';
 import { vectorIndexQuery } from 'src/utils/database';
 import { MigrationInterface, QueryRunner } from 'typeorm';

-const vectorExtension = new ConfigRepository().getEnv().database.vectorExtension;
-
 export class AddFaceEmbeddingIndex1700714033632 implements MigrationInterface {
   name = 'AddFaceEmbeddingIndex1700714033632';

   public async up(queryRunner: QueryRunner): Promise<void> {
+    const vectorExtension = await getVectorExtension(queryRunner);
     await queryRunner.query(`SET search_path TO "$user", public, vectors`);

     await queryRunner.query(vectorIndexQuery({ vectorExtension, table: 'asset_faces', indexName: 'face_index' }));
@@ -1,12 +1,11 @@
 import { DatabaseExtension } from 'src/enum';
-import { ConfigRepository } from 'src/repositories/config.repository';
+import { getVectorExtension } from 'src/repositories/database.repository';
 import { vectorIndexQuery } from 'src/utils/database';
 import { MigrationInterface, QueryRunner } from 'typeorm';

-const vectorExtension = new ConfigRepository().getEnv().database.vectorExtension;
-
 export class AddFaceSearchRelation1718486162779 implements MigrationInterface {
   public async up(queryRunner: QueryRunner): Promise<void> {
-    await queryRunner.query(`SET search_path TO "$user", public, vectors`);
+    const vectorExtension = await getVectorExtension(queryRunner);
+    if (vectorExtension === DatabaseExtension.VECTORS) {
+      await queryRunner.query(`SET search_path TO "$user", public, vectors`);
+    }
@@ -48,11 +47,11 @@ export class AddFaceSearchRelation1718486162779 implements MigrationInterface {
     await queryRunner.query(`ALTER TABLE face_search ALTER COLUMN embedding SET DATA TYPE vector(512)`);

     await queryRunner.query(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: 'clip_index' }));

     await queryRunner.query(vectorIndexQuery({ vectorExtension, table: 'face_search', indexName: 'face_index' }));
   }

   public async down(queryRunner: QueryRunner): Promise<void> {
-    await queryRunner.query(`SET search_path TO "$user", public, vectors`);
+    const vectorExtension = await getVectorExtension(queryRunner);
+    if (vectorExtension === DatabaseExtension.VECTORS) {
+      await queryRunner.query(`SET search_path TO "$user", public, vectors`);
+    }
@@ -11,11 +11,3 @@ WHERE

 -- DatabaseRepository.getPostgresVersion
 SHOW server_version
-
--- DatabaseRepository.shouldReindex
-SELECT
-  idx_status
-FROM
-  pg_vector_index_stat
-WHERE
-  indexname = $1
@@ -204,6 +204,21 @@ where
   "person"."ownerId" = $3
   and "asset_faces"."deletedAt" is null

+-- PersonRepository.refreshFaces
+with
+  "added_embeddings" as (
+    insert into
+      "face_search" ("faceId", "embedding")
+    values
+      ($1, $2)
+  )
+select
+from
+  (
+    select
+      1
+  ) as "dummy"
+
 -- PersonRepository.getFacesByIds
 select
   "asset_faces".*,
@@ -64,6 +64,9 @@ limit
   $15

 -- SearchRepository.searchSmart
+begin
+set
+  local vchordrq.probes = 1
 select
   "assets".*
 from
@@ -83,8 +86,12 @@
   $7
 offset
   $8
+commit

 -- SearchRepository.searchDuplicates
+begin
+set
+  local vchordrq.probes = 1
 with
   "cte" as (
     select
@@ -102,18 +109,22 @@ with
       and "assets"."id" != $5::uuid
       and "assets"."stackId" is null
     order by
-      smart_search.embedding <=> $6
+      "distance"
     limit
-      $7
+      $6
   )
 select
   *
 from
   "cte"
 where
-  "cte"."distance" <= $8
+  "cte"."distance" <= $7
+commit

 -- SearchRepository.searchFaces
+begin
+set
+  local vchordrq.probes = 1
 with
   "cte" as (
     select
@@ -129,16 +140,17 @@ with
       "assets"."ownerId" = any ($2::uuid[])
       and "assets"."deletedAt" is null
     order by
-      face_search.embedding <=> $3
+      "distance"
     limit
-      $4
+      $3
   )
 select
   *
 from
   "cte"
 where
-  "cte"."distance" <= $5
+  "cte"."distance" <= $4
+commit

 -- SearchRepository.searchPlaces
 select
@@ -89,7 +89,7 @@ describe('getEnv', () => {
       password: 'postgres',
     },
     skipMigrations: false,
-    vectorExtension: 'vectors',
+    vectorExtension: undefined,
   });
 });

@@ -58,7 +58,7 @@ export interface EnvData {
   database: {
     config: DatabaseConnectionParams;
     skipMigrations: boolean;
-    vectorExtension: VectorExtension;
+    vectorExtension?: VectorExtension;
   };

   licensePublicKey: {
@@ -196,6 +196,22 @@ const getEnv = (): EnvData => {
     ssl: dto.DB_SSL_MODE || undefined,
   };

+  let vectorExtension: VectorExtension | undefined;
+  switch (dto.DB_VECTOR_EXTENSION) {
+    case 'pgvector': {
+      vectorExtension = DatabaseExtension.VECTOR;
+      break;
+    }
+    case 'pgvecto.rs': {
+      vectorExtension = DatabaseExtension.VECTORS;
+      break;
+    }
+    case 'vectorchord': {
+      vectorExtension = DatabaseExtension.VECTORCHORD;
+      break;
+    }
+  }
+
   return {
     host: dto.IMMICH_HOST,
     port: dto.IMMICH_PORT || 2283,
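Note: a standalone sketch of the mapping above, showing how each DB_VECTOR_EXTENSION value resolves to an internal enum value (enum string values taken from the diff; leaving the variable unset now means auto-detection at startup):

// illustrative only; mirrors the switch in getEnv()
const mapVectorExtension = (env?: string): 'vector' | 'vectors' | 'vchord' | undefined => {
  switch (env) {
    case 'pgvector': {
      return 'vector';
    }
    case 'pgvecto.rs': {
      return 'vectors';
    }
    case 'vectorchord': {
      return 'vchord';
    }
    default: {
      return undefined; // unset: the extension is auto-detected from pg_available_extensions
    }
  }
};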
@@ -251,7 +267,7 @@
     database: {
       config: databaseConnection,
       skipMigrations: dto.DB_SKIP_MIGRATIONS ?? false,
-      vectorExtension: dto.DB_VECTOR_EXTENSION === 'pgvector' ? DatabaseExtension.VECTOR : DatabaseExtension.VECTORS,
+      vectorExtension,
     },

     licensePublicKey: isProd ? productionKeys : stagingKeys,
@@ -5,7 +5,16 @@ import { InjectKysely } from 'nestjs-kysely';
 import { readdir } from 'node:fs/promises';
 import { join, resolve } from 'node:path';
 import semver from 'semver';
-import { EXTENSION_NAMES, POSTGRES_VERSION_RANGE, VECTOR_VERSION_RANGE, VECTORS_VERSION_RANGE } from 'src/constants';
+import {
+  EXTENSION_NAMES,
+  POSTGRES_VERSION_RANGE,
+  VECTOR_EXTENSIONS,
+  VECTOR_INDEX_TABLES,
+  VECTOR_VERSION_RANGE,
+  VECTORCHORD_LIST_SLACK_FACTOR,
+  VECTORCHORD_VERSION_RANGE,
+  VECTORS_VERSION_RANGE,
+} from 'src/constants';
 import { DB } from 'src/db';
 import { GenerateSql } from 'src/decorators';
 import { DatabaseExtension, DatabaseLock, VectorIndex } from 'src/enum';
@@ -14,11 +23,42 @@ import { LoggingRepository } from 'src/repositories/logging.repository';
 import { ExtensionVersion, VectorExtension, VectorUpdateResult } from 'src/types';
 import { vectorIndexQuery } from 'src/utils/database';
 import { isValidInteger } from 'src/validation';
-import { DataSource } from 'typeorm';
+import { DataSource, QueryRunner } from 'typeorm';
+
+export let cachedVectorExtension: VectorExtension | undefined;
+export async function getVectorExtension(runner: Kysely<DB> | QueryRunner): Promise<VectorExtension> {
+  if (cachedVectorExtension) {
+    return cachedVectorExtension;
+  }
+
+  cachedVectorExtension = new ConfigRepository().getEnv().database.vectorExtension;
+  if (cachedVectorExtension) {
+    return cachedVectorExtension;
+  }
+
+  let availableExtensions: { name: VectorExtension }[];
+  const query = `SELECT name FROM pg_available_extensions WHERE name IN (${VECTOR_EXTENSIONS.map((ext) => `'${ext}'`).join(', ')})`;
+  if (runner instanceof Kysely) {
+    const { rows } = await sql.raw<{ name: VectorExtension }>(query).execute(runner);
+    availableExtensions = rows;
+  } else {
+    availableExtensions = (await runner.query(query)) as { name: VectorExtension }[];
+  }
+  const extensionNames = new Set(availableExtensions.map((row) => row.name));
+  cachedVectorExtension = VECTOR_EXTENSIONS.find((ext) => extensionNames.has(ext));
+  if (!cachedVectorExtension) {
+    throw new Error(`No vector extension found. Available extensions: ${VECTOR_EXTENSIONS.join(', ')}`);
+  }
+  return cachedVectorExtension;
+}
+
+export const probes: Record<VectorIndex, number> = {
+  [VectorIndex.CLIP]: 1,
+  [VectorIndex.FACE]: 1,
+};

 @Injectable()
 export class DatabaseRepository {
-  private vectorExtension: VectorExtension;
   private readonly asyncLock = new AsyncLock();

   constructor(
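Note: the new getVectorExtension above resolves the extension once per process and caches it. An explicit DB_VECTOR_EXTENSION always wins; otherwise the first installed match in VECTOR_EXTENSIONS order is chosen, so VectorChord is preferred whenever several extensions are available. A self-contained sketch of that selection (hypothetical helper, not the repository code):

const VECTOR_EXTENSIONS = ['vchord', 'vectors', 'vector'] as const;
type VectorExtension = (typeof VECTOR_EXTENSIONS)[number];

function pickVectorExtension(available: Set<string>, configured?: VectorExtension): VectorExtension {
  if (configured) {
    return configured; // explicit configuration always wins
  }
  const found = VECTOR_EXTENSIONS.find((ext) => available.has(ext));
  if (!found) {
    throw new Error(`No vector extension found. Available extensions: ${VECTOR_EXTENSIONS.join(', ')}`);
  }
  return found;
}

// a database that has both pgvector and VectorChord installed resolves to 'vchord':
console.log(pickVectorExtension(new Set(['vector', 'vchord']))); // 'vchord'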
@@ -26,7 +66,6 @@ export class DatabaseRepository {
     private logger: LoggingRepository,
     private configRepository: ConfigRepository,
   ) {
-    this.vectorExtension = configRepository.getEnv().database.vectorExtension;
     this.logger.setContext(DatabaseRepository.name);
   }

@@ -34,6 +73,10 @@ export class DatabaseRepository {
     await this.db.destroy();
   }

+  getVectorExtension(): Promise<VectorExtension> {
+    return getVectorExtension(this.db);
+  }
+
   @GenerateSql({ params: [DatabaseExtension.VECTORS] })
   async getExtensionVersion(extension: DatabaseExtension): Promise<ExtensionVersion> {
     const { rows } = await sql<ExtensionVersion>`
@@ -45,7 +88,20 @@
   }

   getExtensionVersionRange(extension: VectorExtension): string {
-    return extension === DatabaseExtension.VECTORS ? VECTORS_VERSION_RANGE : VECTOR_VERSION_RANGE;
+    switch (extension) {
+      case DatabaseExtension.VECTORCHORD: {
+        return VECTORCHORD_VERSION_RANGE;
+      }
+      case DatabaseExtension.VECTORS: {
+        return VECTORS_VERSION_RANGE;
+      }
+      case DatabaseExtension.VECTOR: {
+        return VECTOR_VERSION_RANGE;
+      }
+      default: {
+        throw new Error(`Unsupported vector extension: '${extension}'`);
+      }
+    }
   }

   @GenerateSql()
@@ -59,7 +115,14 @@
   }

   async createExtension(extension: DatabaseExtension): Promise<void> {
-    await sql`CREATE EXTENSION IF NOT EXISTS ${sql.raw(extension)}`.execute(this.db);
+    await sql`CREATE EXTENSION IF NOT EXISTS ${sql.raw(extension)} CASCADE`.execute(this.db);
+    if (extension === DatabaseExtension.VECTORCHORD) {
+      const dbName = sql.table(await this.getDatabaseName());
+      await sql`ALTER DATABASE ${dbName} SET vchordrq.prewarm_dim = '512,640,768,1024,1152,1536'`.execute(this.db);
+      await sql`SET vchordrq.prewarm_dim = '512,640,768,1024,1152,1536'`.execute(this.db);
+      await sql`ALTER DATABASE ${dbName} SET vchordrq.probes = 1`.execute(this.db);
+      await sql`SET vchordrq.probes = 1`.execute(this.db);
+    }
   }

   async updateVectorExtension(extension: VectorExtension, targetVersion?: string): Promise<VectorUpdateResult> {
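Note: createExtension above issues each vchordrq setting twice on purpose. In Postgres, ALTER DATABASE ... SET only takes effect for connections opened afterwards, so a plain SET is also needed for the session that is already open. A sketch of the pattern (assumes a generic query helper; values copied from the diff):

// illustrative helper, not the repository code
async function applyVchordDefaults(query: (statement: string) => Promise<void>, dbName: string): Promise<void> {
  // persists the default for future connections...
  await query(`ALTER DATABASE "${dbName}" SET vchordrq.prewarm_dim = '512,640,768,1024,1152,1536'`);
  // ...and applies it to the current session, which ALTER DATABASE alone would not affect
  await query(`SET vchordrq.prewarm_dim = '512,640,768,1024,1152,1536'`);
}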
@@ -78,120 +141,201 @@
     await this.db.transaction().execute(async (tx) => {
       await this.setSearchPath(tx);

       if (isVectors && installedVersion === '0.1.1') {
         await this.setExtVersion(tx, DatabaseExtension.VECTORS, '0.1.11');
       }

       const isSchemaUpgrade = semver.satisfies(installedVersion, '0.1.1 || 0.1.11');
       if (isSchemaUpgrade && isVectors) {
         await this.updateVectorsSchema(tx);
       }

       await sql`ALTER EXTENSION ${sql.raw(extension)} UPDATE TO ${sql.lit(targetVersion)}`.execute(tx);

       const diff = semver.diff(installedVersion, targetVersion);
-      if (isVectors && diff && ['minor', 'major'].includes(diff)) {
+      if (isVectors && (diff === 'major' || diff === 'minor')) {
         await sql`SELECT pgvectors_upgrade()`.execute(tx);
         restartRequired = true;
-      } else {
-        await this.reindex(VectorIndex.CLIP);
-        await this.reindex(VectorIndex.FACE);
+      } else if (diff) {
+        await Promise.all([this.reindexVectors(VectorIndex.CLIP), this.reindexVectors(VectorIndex.FACE)]);
       }
     });

     return { restartRequired };
   }

-  async reindex(index: VectorIndex): Promise<void> {
-    try {
-      await sql`REINDEX INDEX ${sql.raw(index)}`.execute(this.db);
-    } catch (error) {
-      if (this.vectorExtension !== DatabaseExtension.VECTORS) {
-        throw error;
-      }
-      this.logger.warn(`Could not reindex index ${index}. Attempting to auto-fix.`);
-
-      const table = await this.getIndexTable(index);
-      const dimSize = await this.getDimSize(table);
-      await this.db.transaction().execute(async (tx) => {
-        await this.setSearchPath(tx);
-        await sql`DROP INDEX IF EXISTS ${sql.raw(index)}`.execute(tx);
-        await sql`ALTER TABLE ${sql.raw(table)} ALTER COLUMN embedding SET DATA TYPE real[]`.execute(tx);
-        await sql`ALTER TABLE ${sql.raw(table)} ALTER COLUMN embedding SET DATA TYPE vector(${sql.raw(String(dimSize))})`.execute(
-          tx,
-        );
-        await sql.raw(vectorIndexQuery({ vectorExtension: this.vectorExtension, table, indexName: index })).execute(tx);
-      });
-    }
-  }
+  async prewarm(index: VectorIndex): Promise<void> {
+    const vectorExtension = await getVectorExtension(this.db);
+    if (vectorExtension !== DatabaseExtension.VECTORCHORD) {
+      return;
+    }
+    this.logger.debug(`Prewarming ${index}`);
+    await sql`SELECT vchordrq_prewarm(${index})`.execute(this.db);
+  }

-  @GenerateSql({ params: [VectorIndex.CLIP] })
-  async shouldReindex(name: VectorIndex): Promise<boolean> {
-    if (this.vectorExtension !== DatabaseExtension.VECTORS) {
-      return false;
-    }
-
-    try {
-      const { rows } = await sql<{
-        idx_status: string;
-      }>`SELECT idx_status FROM pg_vector_index_stat WHERE indexname = ${name}`.execute(this.db);
-      return rows[0]?.idx_status === 'UPGRADE';
-    } catch (error) {
-      const message: string = (error as any).message;
-      if (message.includes('index is not existing')) {
-        return true;
-      } else if (message.includes('relation "pg_vector_index_stat" does not exist')) {
-        return false;
-      }
-      throw error;
-    }
-  }
+  async reindexVectorsIfNeeded(names: VectorIndex[]): Promise<void> {
+    const { rows } = await sql<{
+      indexdef: string;
+      indexname: string;
+    }>`SELECT indexdef, indexname FROM pg_indexes WHERE indexname = ANY(ARRAY[${sql.join(names)}])`.execute(this.db);
+
+    const vectorExtension = await getVectorExtension(this.db);
+
+    const promises = [];
+    for (const indexName of names) {
+      const row = rows.find((index) => index.indexname === indexName);
+      const table = VECTOR_INDEX_TABLES[indexName];
+      if (!row) {
+        promises.push(this.reindexVectors(indexName));
+        continue;
+      }
+
+      switch (vectorExtension) {
+        case DatabaseExtension.VECTOR: {
+          if (!row.indexdef.toLowerCase().includes('using hnsw')) {
+            promises.push(this.reindexVectors(indexName));
+          }
+          break;
+        }
+        case DatabaseExtension.VECTORS: {
+          if (!row.indexdef.toLowerCase().includes('using vectors')) {
+            promises.push(this.reindexVectors(indexName));
+          }
+          break;
+        }
+        case DatabaseExtension.VECTORCHORD: {
+          const matches = row.indexdef.match(/(?<=lists = \[)\d+/g);
+          const lists = matches && matches.length > 0 ? Number(matches[0]) : 1;
+          promises.push(
+            this.db
+              .selectFrom(this.db.dynamic.table(table).as('t'))
+              .select((eb) => eb.fn.countAll<number>().as('count'))
+              .executeTakeFirstOrThrow()
+              .then(({ count }) => {
+                const targetLists = this.targetListCount(count);
+                this.logger.log(`targetLists=${targetLists}, current=${lists} for ${indexName} of ${count} rows`);
+                if (
+                  !row.indexdef.toLowerCase().includes('using vchordrq') ||
+                  // slack factor is to avoid frequent reindexing if the count is borderline
+                  (lists !== targetLists && lists !== this.targetListCount(count * VECTORCHORD_LIST_SLACK_FACTOR))
+                ) {
+                  probes[indexName] = this.targetProbeCount(targetLists);
+                  return this.reindexVectors(indexName, { lists: targetLists });
+                } else {
+                  probes[indexName] = this.targetProbeCount(lists);
+                }
+              }),
+          );
+          break;
+        }
+      }
+    }
+
+    if (promises.length > 0) {
+      await Promise.all(promises);
+    }
+  }
+
+  private async reindexVectors(indexName: VectorIndex, { lists }: { lists?: number } = {}): Promise<void> {
+    this.logger.log(`Reindexing ${indexName}`);
+    const table = VECTOR_INDEX_TABLES[indexName];
+    const vectorExtension = await getVectorExtension(this.db);
+    const { rows } = await sql<{
+      columnName: string;
+    }>`SELECT column_name as "columnName" FROM information_schema.columns WHERE table_name = ${table}`.execute(this.db);
+    if (rows.length === 0) {
+      this.logger.warn(
+        `Table ${table} does not exist, skipping reindexing. This is only normal if this is a new Immich instance.`,
+      );
+      return;
+    }
+
+    const dimSize = await this.getDimensionSize(table);
+    await this.db.transaction().execute(async (tx) => {
+      await sql`DROP INDEX IF EXISTS ${sql.raw(indexName)}`.execute(tx);
+      if (!rows.some((row) => row.columnName === 'embedding')) {
+        this.logger.warn(`Column 'embedding' does not exist in table '${table}', truncating and adding column.`);
+        await sql`TRUNCATE TABLE ${sql.raw(table)}`.execute(tx);
+        await sql`ALTER TABLE ${sql.raw(table)} ADD COLUMN embedding real[] NOT NULL`.execute(tx);
+      }
+      await sql`ALTER TABLE ${sql.raw(table)} ALTER COLUMN embedding SET DATA TYPE real[]`.execute(tx);
+      const schema = vectorExtension === DatabaseExtension.VECTORS ? 'vectors.' : '';
+      await sql`
+        ALTER TABLE ${sql.raw(table)}
+        ALTER COLUMN embedding
+        SET DATA TYPE ${sql.raw(schema)}vector(${sql.raw(String(dimSize))})`.execute(tx);
+      await sql.raw(vectorIndexQuery({ vectorExtension, table, indexName, lists })).execute(tx);
+    });
+    this.logger.log(`Reindexed ${indexName}`);
+  }

   private async setSearchPath(tx: Transaction<DB>): Promise<void> {
     await sql`SET search_path TO "$user", public, vectors`.execute(tx);
   }

   private async setExtVersion(tx: Transaction<DB>, extName: DatabaseExtension, version: string): Promise<void> {
     await sql`UPDATE pg_catalog.pg_extension SET extversion = ${version} WHERE extname = ${extName}`.execute(tx);
   }

+  private async getDatabaseName(): Promise<string> {
+    const { rows } = await sql<{ db: string }>`SELECT current_database() as db`.execute(this.db);
+    return rows[0].db;
+  }
+
-  private async getIndexTable(index: VectorIndex): Promise<string> {
-    const { rows } = await sql<{
-      relname: string | null;
-    }>`SELECT relname FROM pg_stat_all_indexes WHERE indexrelname = ${index}`.execute(this.db);
-    const table = rows[0]?.relname;
-    if (!table) {
-      throw new Error(`Could not find table for index ${index}`);
-    }
-    return table;
-  }
-
   private async updateVectorsSchema(tx: Transaction<DB>): Promise<void> {
     const extension = DatabaseExtension.VECTORS;
     await sql`CREATE SCHEMA IF NOT EXISTS ${extension}`.execute(tx);
     await sql`UPDATE pg_catalog.pg_extension SET extrelocatable = true WHERE extname = ${extension}`.execute(tx);
     await sql`ALTER EXTENSION vectors SET SCHEMA vectors`.execute(tx);
     await sql`UPDATE pg_catalog.pg_extension SET extrelocatable = false WHERE extname = ${extension}`.execute(tx);
   }

-  private async getDimSize(table: string, column = 'embedding'): Promise<number> {
+  async getDimensionSize(table: string, column = 'embedding'): Promise<number> {
     const { rows } = await sql<{ dimsize: number }>`
       SELECT atttypmod as dimsize
       FROM pg_attribute f
       JOIN pg_class c ON c.oid = f.attrelid
       WHERE c.relkind = 'r'::char
         AND f.attnum > 0
-        AND c.relname = ${table}
-        AND f.attname = '${column}'
+        AND c.relname = ${table}::text
+        AND f.attname = ${column}::text
       `.execute(this.db);

     const dimSize = rows[0]?.dimsize;
     if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
-      throw new Error(`Could not retrieve dimension size`);
+      this.logger.warn(`Could not retrieve dimension size of column '${column}' in table '${table}', assuming 512`);
+      return 512;
     }
     return dimSize;
   }

+  async setDimensionSize(dimSize: number): Promise<void> {
+    if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
+      throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
+    }
+
+    // this is done in two transactions to handle concurrent writes
+    await this.db.transaction().execute(async (trx) => {
+      await sql`delete from ${sql.table('smart_search')}`.execute(trx);
+      await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
+      await sql`alter table ${sql.table('smart_search')} add constraint dim_size_constraint check (array_length(embedding::real[], 1) = ${sql.lit(dimSize)})`.execute(
+        trx,
+      );
+    });
+
+    const vectorExtension = await this.getVectorExtension();
+    await this.db.transaction().execute(async (trx) => {
+      await sql`drop index if exists clip_index`.execute(trx);
+      await trx.schema
+        .alterTable('smart_search')
+        .alterColumn('embedding', (col) => col.setDataType(sql.raw(`vector(${dimSize})`)))
+        .execute();
+      await sql
+        .raw(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: VectorIndex.CLIP }))
+        .execute(trx);
+      await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
+    });
+    probes[VectorIndex.CLIP] = 1;
+
+    await sql`vacuum analyze ${sql.table('smart_search')}`.execute(this.db);
+  }
+
+  async deleteAllSearchEmbeddings(): Promise<void> {
+    await sql`truncate ${sql.table('smart_search')}`.execute(this.db);
+  }
+
+  private targetListCount(count: number) {
+    if (count < 128_000) {
+      return 1;
+    } else if (count < 2_048_000) {
+      return 1 << (32 - Math.clz32(count / 1000));
+    } else {
+      return 1 << (33 - Math.clz32(Math.sqrt(count)));
+    }
+  }
+
+  private targetProbeCount(lists: number) {
+    return Math.ceil(lists / 8);
+  }
+
   async runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void> {
     const { database } = this.configRepository.getEnv();
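Note on the sizing logic above: targetListCount picks roughly the next power of two of rows/1000 for mid-sized tables and switches to a sqrt-based rule past about 2M rows, targetProbeCount then scans an eighth of the lists, and the VECTORCHORD_LIST_SLACK_FACTOR check skips reindexing when the row count is merely hovering around a bucket boundary. A standalone sketch of the arithmetic (same formulas as the diff):

// Math.clz32 counts the leading zero bits of the value coerced to uint32
const targetListCount = (count: number): number => {
  if (count < 128_000) {
    return 1;
  } else if (count < 2_048_000) {
    return 1 << (32 - Math.clz32(count / 1000));
  } else {
    return 1 << (33 - Math.clz32(Math.sqrt(count)));
  }
};
const targetProbeCount = (lists: number): number => Math.ceil(lists / 8);

// e.g. 100k rows -> 1 list / 1 probe, 1M -> 1024 / 128, 4M -> 4096 / 512
for (const rows of [100_000, 1_000_000, 4_000_000]) {
  const lists = targetListCount(rows);
  console.log(`${rows} rows: lists=${lists}, probes=${targetProbeCount(lists)}`);
}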
@@ -398,6 +398,7 @@ export class PersonRepository {
     return results.map(({ id }) => id);
   }

+  @GenerateSql({ params: [[], [], [{ faceId: DummyValue.UUID, embedding: DummyValue.VECTOR }]] })
   async refreshFaces(
     facesToAdd: (Insertable<AssetFaces> & { assetId: string })[],
     faceIdsToRemove: string[],
@@ -5,9 +5,9 @@ import { randomUUID } from 'node:crypto';
 import { DB, Exif } from 'src/db';
 import { DummyValue, GenerateSql } from 'src/decorators';
 import { MapAsset } from 'src/dtos/asset-response.dto';
-import { AssetStatus, AssetType, AssetVisibility } from 'src/enum';
-import { ConfigRepository } from 'src/repositories/config.repository';
-import { anyUuid, asUuid, searchAssetBuilder, vectorIndexQuery } from 'src/utils/database';
+import { AssetStatus, AssetType, AssetVisibility, VectorIndex } from 'src/enum';
+import { probes } from 'src/repositories/database.repository';
+import { anyUuid, asUuid, searchAssetBuilder } from 'src/utils/database';
 import { paginationHelper } from 'src/utils/pagination';
 import { isValidInteger } from 'src/validation';
@@ -168,10 +168,7 @@ export interface GetCameraMakesOptions {

 @Injectable()
 export class SearchRepository {
-  constructor(
-    @InjectKysely() private db: Kysely<DB>,
-    private configRepository: ConfigRepository,
-  ) {}
+  constructor(@InjectKysely() private db: Kysely<DB>) {}

   @GenerateSql({
     params: [
@@ -236,19 +233,21 @@
       },
     ],
   })
-  async searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions) {
+  searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions) {
     if (!isValidInteger(pagination.size, { min: 1, max: 1000 })) {
       throw new Error(`Invalid value for 'size': ${pagination.size}`);
     }

-    const items = await searchAssetBuilder(this.db, options)
-      .innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
-      .orderBy(sql`smart_search.embedding <=> ${options.embedding}`)
-      .limit(pagination.size + 1)
-      .offset((pagination.page - 1) * pagination.size)
-      .execute();
-
-    return paginationHelper(items, pagination.size);
+    return this.db.transaction().execute(async (trx) => {
+      await sql`set local vchordrq.probes = ${sql.lit(probes[VectorIndex.CLIP])}`.execute(trx);
+      const items = await searchAssetBuilder(trx, options)
+        .innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
+        .orderBy(sql`smart_search.embedding <=> ${options.embedding}`)
+        .limit(pagination.size + 1)
+        .offset((pagination.page - 1) * pagination.size)
+        .execute();
+      return paginationHelper(items, pagination.size);
+    });
   }

   @GenerateSql({
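Note: the searches now run inside a short transaction so that `set local vchordrq.probes` applies only to that query; a LOCAL setting reverts when the transaction ends, so concurrent searches on pooled connections keep their own probe counts. A minimal sketch of the pattern (kysely-style, mirroring the code above; `withProbes` is a hypothetical helper):

import { Kysely, sql } from 'kysely';

// illustrative only: per-query probe tuning via a transaction-scoped GUC
async function withProbes<DB, T>(db: Kysely<DB>, probeCount: number, run: (trx: Kysely<DB>) => Promise<T>): Promise<T> {
  return db.transaction().execute(async (trx) => {
    await sql`set local vchordrq.probes = ${sql.lit(probeCount)}`.execute(trx);
    return run(trx); // `set local` reverts automatically when the transaction ends
  });
}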
@@ -263,29 +262,32 @@
     ],
   })
   searchDuplicates({ assetId, embedding, maxDistance, type, userIds }: AssetDuplicateSearch) {
-    return this.db
-      .with('cte', (qb) =>
-        qb
-          .selectFrom('assets')
-          .select([
-            'assets.id as assetId',
-            'assets.duplicateId',
-            sql<number>`smart_search.embedding <=> ${embedding}`.as('distance'),
-          ])
-          .innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
-          .where('assets.ownerId', '=', anyUuid(userIds))
-          .where('assets.deletedAt', 'is', null)
-          .where('assets.visibility', '!=', AssetVisibility.HIDDEN)
-          .where('assets.type', '=', type)
-          .where('assets.id', '!=', asUuid(assetId))
-          .where('assets.stackId', 'is', null)
-          .orderBy(sql`smart_search.embedding <=> ${embedding}`)
-          .limit(64),
-      )
-      .selectFrom('cte')
-      .selectAll()
-      .where('cte.distance', '<=', maxDistance as number)
-      .execute();
+    return this.db.transaction().execute(async (trx) => {
+      await sql`set local vchordrq.probes = ${sql.lit(probes[VectorIndex.CLIP])}`.execute(trx);
+      return await trx
+        .with('cte', (qb) =>
+          qb
+            .selectFrom('assets')
+            .select([
+              'assets.id as assetId',
+              'assets.duplicateId',
+              sql<number>`smart_search.embedding <=> ${embedding}`.as('distance'),
+            ])
+            .innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
+            .where('assets.ownerId', '=', anyUuid(userIds))
+            .where('assets.deletedAt', 'is', null)
+            .where('assets.visibility', '!=', AssetVisibility.HIDDEN)
+            .where('assets.type', '=', type)
+            .where('assets.id', '!=', asUuid(assetId))
+            .where('assets.stackId', 'is', null)
+            .orderBy('distance')
+            .limit(64),
+        )
+        .selectFrom('cte')
+        .selectAll()
+        .where('cte.distance', '<=', maxDistance as number)
+        .execute();
+    });
   }

   @GenerateSql({
@@ -303,31 +305,36 @@
       throw new Error(`Invalid value for 'numResults': ${numResults}`);
     }

-    return this.db
-      .with('cte', (qb) =>
-        qb
-          .selectFrom('asset_faces')
-          .select([
-            'asset_faces.id',
-            'asset_faces.personId',
-            sql<number>`face_search.embedding <=> ${embedding}`.as('distance'),
-          ])
-          .innerJoin('assets', 'assets.id', 'asset_faces.assetId')
-          .innerJoin('face_search', 'face_search.faceId', 'asset_faces.id')
-          .leftJoin('person', 'person.id', 'asset_faces.personId')
-          .where('assets.ownerId', '=', anyUuid(userIds))
-          .where('assets.deletedAt', 'is', null)
-          .$if(!!hasPerson, (qb) => qb.where('asset_faces.personId', 'is not', null))
-          .$if(!!minBirthDate, (qb) =>
-            qb.where((eb) => eb.or([eb('person.birthDate', 'is', null), eb('person.birthDate', '<=', minBirthDate!)])),
-          )
-          .orderBy(sql`face_search.embedding <=> ${embedding}`)
-          .limit(numResults),
-      )
-      .selectFrom('cte')
-      .selectAll()
-      .where('cte.distance', '<=', maxDistance)
-      .execute();
+    return this.db.transaction().execute(async (trx) => {
+      await sql`set local vchordrq.probes = ${sql.lit(probes[VectorIndex.FACE])}`.execute(trx);
+      return await trx
+        .with('cte', (qb) =>
+          qb
+            .selectFrom('asset_faces')
+            .select([
+              'asset_faces.id',
+              'asset_faces.personId',
+              sql<number>`face_search.embedding <=> ${embedding}`.as('distance'),
+            ])
+            .innerJoin('assets', 'assets.id', 'asset_faces.assetId')
+            .innerJoin('face_search', 'face_search.faceId', 'asset_faces.id')
+            .leftJoin('person', 'person.id', 'asset_faces.personId')
+            .where('assets.ownerId', '=', anyUuid(userIds))
+            .where('assets.deletedAt', 'is', null)
+            .$if(!!hasPerson, (qb) => qb.where('asset_faces.personId', 'is not', null))
+            .$if(!!minBirthDate, (qb) =>
+              qb.where((eb) =>
+                eb.or([eb('person.birthDate', 'is', null), eb('person.birthDate', '<=', minBirthDate!)]),
+              ),
+            )
+            .orderBy('distance')
+            .limit(numResults),
+        )
+        .selectFrom('cte')
+        .selectAll()
+        .where('cte.distance', '<=', maxDistance)
+        .execute();
+    });
   }

   @GenerateSql({ params: [DummyValue.STRING] })
@@ -416,56 +423,6 @@
       .execute();
   }

-  async getDimensionSize(): Promise<number> {
-    const { rows } = await sql<{ dimsize: number }>`
-      select atttypmod as dimsize
-      from pg_attribute f
-      join pg_class c ON c.oid = f.attrelid
-      where c.relkind = 'r'::char
-        and f.attnum > 0
-        and c.relname = 'smart_search'
-        and f.attname = 'embedding'
-      `.execute(this.db);
-
-    const dimSize = rows[0]['dimsize'];
-    if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
-      throw new Error(`Could not retrieve CLIP dimension size`);
-    }
-    return dimSize;
-  }
-
-  async setDimensionSize(dimSize: number): Promise<void> {
-    if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
-      throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
-    }
-
-    // this is done in two transactions to handle concurrent writes
-    await this.db.transaction().execute(async (trx) => {
-      await sql`delete from ${sql.table('smart_search')}`.execute(trx);
-      await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
-      await sql`alter table ${sql.table('smart_search')} add constraint dim_size_constraint check (array_length(embedding::real[], 1) = ${sql.lit(dimSize)})`.execute(
-        trx,
-      );
-    });
-
-    const vectorExtension = this.configRepository.getEnv().database.vectorExtension;
-    await this.db.transaction().execute(async (trx) => {
-      await sql`drop index if exists clip_index`.execute(trx);
-      await trx.schema
-        .alterTable('smart_search')
-        .alterColumn('embedding', (col) => col.setDataType(sql.raw(`vector(${dimSize})`)))
-        .execute();
-      await sql.raw(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: 'clip_index' })).execute(trx);
-      await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
-    });
-
-    await sql`vacuum analyze ${sql.table('smart_search')}`.execute(this.db);
-  }
-
-  async deleteAllSearchEmbeddings(): Promise<void> {
-    await sql`truncate ${sql.table('smart_search')}`.execute(this.db);
-  }
-
   async getCountries(userIds: string[]): Promise<string[]> {
     const res = await this.getExifField('country', userIds).execute();
     return res.map((row) => row.country!);
@@ -1,10 +1,9 @@
 import { Kysely, sql } from 'kysely';
 import { DatabaseExtension } from 'src/enum';
-import { ConfigRepository } from 'src/repositories/config.repository';
+import { getVectorExtension } from 'src/repositories/database.repository';
 import { LoggingRepository } from 'src/repositories/logging.repository';
 import { vectorIndexQuery } from 'src/utils/database';

-const vectorExtension = new ConfigRepository().getEnv().database.vectorExtension;
 const lastMigrationSql = sql<{ name: string }>`SELECT "name" FROM "migrations" ORDER BY "timestamp" DESC LIMIT 1;`;
 const tableExists = sql<{ result: string | null }>`select to_regclass('migrations') as "result"`;
 const logger = LoggingRepository.create();
@@ -25,12 +24,14 @@ export async function up(db: Kysely<any>): Promise<void> {
     return;
   }

+  const vectorExtension = await getVectorExtension(db);
+
   await sql`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`.execute(db);
   await sql`CREATE EXTENSION IF NOT EXISTS "unaccent";`.execute(db);
   await sql`CREATE EXTENSION IF NOT EXISTS "cube";`.execute(db);
   await sql`CREATE EXTENSION IF NOT EXISTS "earthdistance";`.execute(db);
   await sql`CREATE EXTENSION IF NOT EXISTS "pg_trgm";`.execute(db);
-  await sql`CREATE EXTENSION IF NOT EXISTS ${sql.raw(vectorExtension)}`.execute(db);
+  await sql`CREATE EXTENSION IF NOT EXISTS ${sql.raw(vectorExtension)} CASCADE`.execute(db);
   await sql`CREATE OR REPLACE FUNCTION immich_uuid_v7(p_timestamp timestamp with time zone default clock_timestamp())
     RETURNS uuid
     VOLATILE LANGUAGE SQL
@@ -1,5 +1,5 @@
 import { EXTENSION_NAMES } from 'src/constants';
-import { DatabaseExtension } from 'src/enum';
+import { DatabaseExtension, VectorIndex } from 'src/enum';
 import { DatabaseService } from 'src/services/database.service';
 import { VectorExtension } from 'src/types';
 import { mockEnvData } from 'test/repositories/config.repository.mock';
@@ -47,8 +47,10 @@ describe(DatabaseService.name, () => {
   describe.each(<Array<{ extension: VectorExtension; extensionName: string }>>[
     { extension: DatabaseExtension.VECTOR, extensionName: EXTENSION_NAMES[DatabaseExtension.VECTOR] },
     { extension: DatabaseExtension.VECTORS, extensionName: EXTENSION_NAMES[DatabaseExtension.VECTORS] },
+    { extension: DatabaseExtension.VECTORCHORD, extensionName: EXTENSION_NAMES[DatabaseExtension.VECTORCHORD] },
   ])('should work with $extensionName', ({ extension, extensionName }) => {
     beforeEach(() => {
+      mocks.database.getVectorExtension.mockResolvedValue(extension);
       mocks.config.getEnv.mockReturnValue(
         mockEnvData({
           database: {
@@ -240,41 +242,32 @@ describe(DatabaseService.name, () => {
     });

     it(`should reindex ${extension} indices if needed`, async () => {
-      mocks.database.shouldReindex.mockResolvedValue(true);
-
       await expect(sut.onBootstrap()).resolves.toBeUndefined();

-      expect(mocks.database.shouldReindex).toHaveBeenCalledTimes(2);
-      expect(mocks.database.reindex).toHaveBeenCalledTimes(2);
+      expect(mocks.database.reindexVectorsIfNeeded).toHaveBeenCalledExactlyOnceWith([
+        VectorIndex.CLIP,
+        VectorIndex.FACE,
+      ]);
+      expect(mocks.database.reindexVectorsIfNeeded).toHaveBeenCalledTimes(1);
       expect(mocks.database.runMigrations).toHaveBeenCalledTimes(1);
       expect(mocks.logger.fatal).not.toHaveBeenCalled();
     });

     it(`should throw an error if reindexing fails`, async () => {
-      mocks.database.shouldReindex.mockResolvedValue(true);
-      mocks.database.reindex.mockRejectedValue(new Error('Error reindexing'));
+      mocks.database.reindexVectorsIfNeeded.mockRejectedValue(new Error('Error reindexing'));

       await expect(sut.onBootstrap()).rejects.toBeDefined();

-      expect(mocks.database.shouldReindex).toHaveBeenCalledTimes(1);
-      expect(mocks.database.reindex).toHaveBeenCalledTimes(1);
+      expect(mocks.database.reindexVectorsIfNeeded).toHaveBeenCalledExactlyOnceWith([
+        VectorIndex.CLIP,
+        VectorIndex.FACE,
+      ]);
       expect(mocks.database.runMigrations).not.toHaveBeenCalled();
       expect(mocks.logger.fatal).not.toHaveBeenCalled();
+      expect(mocks.logger.warn).toHaveBeenCalledWith(
+        expect.stringContaining('Could not run vector reindexing checks.'),
+      );
     });
-
-    it(`should not reindex ${extension} indices if not needed`, async () => {
-      mocks.database.shouldReindex.mockResolvedValue(false);
-
-      await expect(sut.onBootstrap()).resolves.toBeUndefined();
-
-      expect(mocks.database.shouldReindex).toHaveBeenCalledTimes(2);
-      expect(mocks.database.reindex).toHaveBeenCalledTimes(0);
-      expect(mocks.database.runMigrations).toHaveBeenCalledTimes(1);
-      expect(mocks.logger.fatal).not.toHaveBeenCalled();
-    });
   });

   it('should skip migrations if DB_SKIP_MIGRATIONS=true', async () => {
@@ -300,23 +293,7 @@ describe(DatabaseService.name, () => {
     expect(mocks.database.runMigrations).not.toHaveBeenCalled();
   });

-  it(`should throw error if pgvector extension could not be created`, async () => {
-    mocks.config.getEnv.mockReturnValue(
-      mockEnvData({
-        database: {
-          config: {
-            connectionType: 'parts',
-            host: 'database',
-            port: 5432,
-            username: 'postgres',
-            password: 'postgres',
-            database: 'immich',
-          },
-          skipMigrations: true,
-          vectorExtension: DatabaseExtension.VECTOR,
-        },
-      }),
-    );
+  it(`should throw error if extension could not be created`, async () => {
     mocks.database.getExtensionVersion.mockResolvedValue({
       installedVersion: null,
       availableVersion: minVersionInRange,
@@ -328,26 +305,7 @@ describe(DatabaseService.name, () => {

     expect(mocks.logger.fatal).toHaveBeenCalledTimes(1);
     expect(mocks.logger.fatal.mock.calls[0][0]).toContain(
-      `Alternatively, if your Postgres instance has pgvecto.rs, you may use this instead`,
-    );
-    expect(mocks.database.createExtension).toHaveBeenCalledTimes(1);
-    expect(mocks.database.updateVectorExtension).not.toHaveBeenCalled();
-    expect(mocks.database.runMigrations).not.toHaveBeenCalled();
-  });
-
-  it(`should throw error if pgvecto.rs extension could not be created`, async () => {
-    mocks.database.getExtensionVersion.mockResolvedValue({
-      installedVersion: null,
-      availableVersion: minVersionInRange,
-    });
-    mocks.database.updateVectorExtension.mockResolvedValue({ restartRequired: false });
-    mocks.database.createExtension.mockRejectedValue(new Error('Failed to create extension'));
-
-    await expect(sut.onBootstrap()).rejects.toThrow('Failed to create extension');
-
-    expect(mocks.logger.fatal).toHaveBeenCalledTimes(1);
-    expect(mocks.logger.fatal.mock.calls[0][0]).toContain(
-      `Alternatively, if your Postgres instance has pgvector, you may use this instead`,
+      `Alternatively, if your Postgres instance has any of vector, vectors, vchord, you may use one of them instead by setting the environment variable 'DB_VECTOR_EXTENSION=<extension name>'`,
     );
     expect(mocks.database.createExtension).toHaveBeenCalledTimes(1);
     expect(mocks.database.updateVectorExtension).not.toHaveBeenCalled();
@@ -6,7 +6,7 @@ import { BootstrapEventPriority, DatabaseExtension, DatabaseLock, VectorIndex }
 import { BaseService } from 'src/services/base.service';
 import { VectorExtension } from 'src/types';

-type CreateFailedArgs = { name: string; extension: string; otherName: string };
+type CreateFailedArgs = { name: string; extension: string; otherExtensions: string[] };
 type UpdateFailedArgs = { name: string; extension: string; availableVersion: string };
 type RestartRequiredArgs = { name: string; availableVersion: string };
 type NightlyVersionArgs = { name: string; extension: string; version: string };
@@ -25,18 +25,15 @@ const messages = {
   outOfRange: ({ name, version, range }: OutOfRangeArgs) =>
     `The ${name} extension version is ${version}, but Immich only supports ${range}.
     Please change ${name} to a compatible version in the Postgres instance.`,
-  createFailed: ({ name, extension, otherName }: CreateFailedArgs) =>
+  createFailed: ({ name, extension, otherExtensions }: CreateFailedArgs) =>
     `Failed to activate ${name} extension.
     Please ensure the Postgres instance has ${name} installed.

    If the Postgres instance already has ${name} installed, Immich may not have the necessary permissions to activate it.
-    In this case, please run 'CREATE EXTENSION IF NOT EXISTS ${extension}' manually as a superuser.
+    In this case, please run 'CREATE EXTENSION IF NOT EXISTS ${extension} CASCADE' manually as a superuser.
     See https://immich.app/docs/guides/database-queries for how to query the database.

-    Alternatively, if your Postgres instance has ${otherName}, you may use this instead by setting the environment variable 'DB_VECTOR_EXTENSION=${otherName}'.
-    Note that switching between the two extensions after a successful startup is not supported.
-    The exception is if your version of Immich prior to upgrading was 1.90.2 or earlier.
-    In this case, you may set either extension now, but you will not be able to switch to the other extension following a successful startup.`,
+    Alternatively, if your Postgres instance has any of ${otherExtensions.join(', ')}, you may use one of them instead by setting the environment variable 'DB_VECTOR_EXTENSION=<extension name>'.`,
   updateFailed: ({ name, extension, availableVersion }: UpdateFailedArgs) =>
     `The ${name} extension can be updated to ${availableVersion}.
     Immich attempted to update the extension, but failed to do so.
@@ -67,8 +64,7 @@ export class DatabaseService extends BaseService {
     }

     await this.databaseRepository.withLock(DatabaseLock.Migrations, async () => {
-      const envData = this.configRepository.getEnv();
-      const extension = envData.database.vectorExtension;
+      const extension = await this.databaseRepository.getVectorExtension();
       const name = EXTENSION_NAMES[extension];
       const extensionRange = this.databaseRepository.getExtensionVersionRange(extension);

@@ -97,12 +93,23 @@ export class DatabaseService extends BaseService {
         throw new Error(messages.invalidDowngrade({ name, extension, availableVersion, installedVersion }));
       }

-      await this.checkReindexing();
+      try {
+        await this.databaseRepository.reindexVectorsIfNeeded([VectorIndex.CLIP, VectorIndex.FACE]);
+      } catch (error) {
+        this.logger.warn(
+          'Could not run vector reindexing checks. If the extension was updated, please restart the Postgres instance. If you are upgrading directly from a version below 1.107.2, please upgrade to 1.107.2 first.',
+        );
+        throw error;
+      }

       const { database } = this.configRepository.getEnv();
       if (!database.skipMigrations) {
         await this.databaseRepository.runMigrations();
       }
+      await Promise.all([
+        this.databaseRepository.prewarm(VectorIndex.CLIP),
+        this.databaseRepository.prewarm(VectorIndex.FACE),
+      ]);
     });
   }
@@ -110,10 +117,13 @@
     try {
       await this.databaseRepository.createExtension(extension);
     } catch (error) {
-      const otherExtension =
-        extension === DatabaseExtension.VECTORS ? DatabaseExtension.VECTOR : DatabaseExtension.VECTORS;
+      const otherExtensions = [
+        DatabaseExtension.VECTOR,
+        DatabaseExtension.VECTORS,
+        DatabaseExtension.VECTORCHORD,
+      ].filter((ext) => ext !== extension);
       const name = EXTENSION_NAMES[extension];
-      this.logger.fatal(messages.createFailed({ name, extension, otherName: EXTENSION_NAMES[otherExtension] }));
+      this.logger.fatal(messages.createFailed({ name, extension, otherExtensions }));
       throw error;
     }
   }
@@ -130,21 +140,4 @@
       throw error;
     }
   }
-
-  private async checkReindexing() {
-    try {
-      if (await this.databaseRepository.shouldReindex(VectorIndex.CLIP)) {
-        await this.databaseRepository.reindex(VectorIndex.CLIP);
-      }
-
-      if (await this.databaseRepository.shouldReindex(VectorIndex.FACE)) {
-        await this.databaseRepository.reindex(VectorIndex.FACE);
-      }
-    } catch (error) {
-      this.logger.warn(
-        'Could not run vector reindexing checks. If the extension was updated, please restart the Postgres instance.',
-      );
-      throw error;
-    }
-  }
 }
@@ -33,6 +33,7 @@ import {
   QueueName,
   SourceType,
   SystemMetadataKey,
+  VectorIndex,
 } from 'src/enum';
 import { BoundingBox } from 'src/repositories/machine-learning.repository';
 import { UpdateFacesData } from 'src/repositories/person.repository';
@@ -418,6 +419,8 @@
       return JobStatus.SKIPPED;
     }

+    await this.databaseRepository.prewarm(VectorIndex.FACE);
+
     const lastRun = new Date().toISOString();
     const facePagination = this.personRepository.getAllFaces(
       force ? undefined : { personId: null, sourceType: SourceType.MACHINE_LEARNING },
@@ -54,28 +54,28 @@ describe(SmartInfoService.name, () => {
     it('should return if machine learning is disabled', async () => {
       await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningDisabled as SystemConfig });

-      expect(mocks.search.getDimensionSize).not.toHaveBeenCalled();
-      expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
-      expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
+      expect(mocks.database.getDimensionSize).not.toHaveBeenCalled();
+      expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
+      expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
     });

     it('should return if model and DB dimension size are equal', async () => {
-      mocks.search.getDimensionSize.mockResolvedValue(512);
+      mocks.database.getDimensionSize.mockResolvedValue(512);

       await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningEnabled as SystemConfig });

-      expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
-      expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
-      expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
+      expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
+      expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
+      expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
     });

     it('should update DB dimension size if model and DB have different values', async () => {
-      mocks.search.getDimensionSize.mockResolvedValue(768);
+      mocks.database.getDimensionSize.mockResolvedValue(768);

       await sut.onConfigInit({ newConfig: systemConfigStub.machineLearningEnabled as SystemConfig });

-      expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
-      expect(mocks.search.setDimensionSize).toHaveBeenCalledWith(512);
+      expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
+      expect(mocks.database.setDimensionSize).toHaveBeenCalledWith(512);
     });
   });

@@ -89,13 +89,13 @@ describe(SmartInfoService.name, () => {
       });

       expect(mocks.systemMetadata.get).not.toHaveBeenCalled();
-      expect(mocks.search.getDimensionSize).not.toHaveBeenCalled();
-      expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
-      expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
+      expect(mocks.database.getDimensionSize).not.toHaveBeenCalled();
+      expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
+      expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
     });

     it('should return if model and DB dimension size are equal', async () => {
-      mocks.search.getDimensionSize.mockResolvedValue(512);
+      mocks.database.getDimensionSize.mockResolvedValue(512);

       await sut.onConfigUpdate({
         newConfig: {
@@ -106,13 +106,13 @@ describe(SmartInfoService.name, () => {
         } as SystemConfig,
       });

-      expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
-      expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
-      expect(mocks.search.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
+      expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
+      expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
+      expect(mocks.database.deleteAllSearchEmbeddings).not.toHaveBeenCalled();
     });

     it('should update DB dimension size if model and DB have different values', async () => {
-      mocks.search.getDimensionSize.mockResolvedValue(512);
+      mocks.database.getDimensionSize.mockResolvedValue(512);

       await sut.onConfigUpdate({
         newConfig: {
@@ -123,12 +123,12 @@ describe(SmartInfoService.name, () => {
         } as SystemConfig,
       });

-      expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
-      expect(mocks.search.setDimensionSize).toHaveBeenCalledWith(768);
+      expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
+      expect(mocks.database.setDimensionSize).toHaveBeenCalledWith(768);
     });

     it('should clear embeddings if old and new models are different', async () => {
-      mocks.search.getDimensionSize.mockResolvedValue(512);
+      mocks.database.getDimensionSize.mockResolvedValue(512);

       await sut.onConfigUpdate({
         newConfig: {
@ -139,9 +139,9 @@ describe(SmartInfoService.name, () => {
|
|||
} as SystemConfig,
|
||||
});
|
||||
|
||||
expect(mocks.search.deleteAllSearchEmbeddings).toHaveBeenCalled();
|
||||
expect(mocks.search.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||
expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
|
||||
expect(mocks.database.deleteAllSearchEmbeddings).toHaveBeenCalled();
|
||||
expect(mocks.database.getDimensionSize).toHaveBeenCalledTimes(1);
|
||||
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
|
||||
});
|
||||
});
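Taken together, the `onConfigUpdate` tests pin down two distinct recovery paths: a dimension mismatch resizes the vector column, while a model swap at the same dimension deletes the now-incomparable embeddings instead. A hedged sketch of that branching; every name here is an illustrative stand-in for the real service wiring:

```ts
type ClipChange = { dimensionChanged: boolean; modelChanged: boolean; newDimSize: number };
type EmbeddingStore = {
  setDimensionSize(dim: number): Promise<void>;
  deleteAllSearchEmbeddings(): Promise<void>;
};

async function applyClipChange(change: ClipChange, db: EmbeddingStore): Promise<void> {
  if (change.dimensionChanged) {
    // resizing the column invalidates old rows, so no separate delete is needed
    await db.setDimensionSize(change.newDimSize);
  } else if (change.modelChanged) {
    // same size, but vectors from different models live in different spaces
    await db.deleteAllSearchEmbeddings();
  }
}
```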
@@ -151,7 +151,7 @@ describe(SmartInfoService.name, () => {

      await sut.handleQueueEncodeClip({});

-      expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
+      expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
    });

    it('should queue the assets without clip embeddings', async () => {
@@ -163,7 +163,7 @@ describe(SmartInfoService.name, () => {
        { name: JobName.SMART_SEARCH, data: { id: assetStub.image.id } },
      ]);
      expect(mocks.assetJob.streamForEncodeClip).toHaveBeenCalledWith(false);
-      expect(mocks.search.setDimensionSize).not.toHaveBeenCalled();
+      expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
    });

    it('should queue all the assets', async () => {
@@ -175,7 +175,7 @@ describe(SmartInfoService.name, () => {
        { name: JobName.SMART_SEARCH, data: { id: assetStub.image.id } },
      ]);
      expect(mocks.assetJob.streamForEncodeClip).toHaveBeenCalledWith(true);
-      expect(mocks.search.setDimensionSize).toHaveBeenCalledExactlyOnceWith(512);
+      expect(mocks.database.setDimensionSize).toHaveBeenCalledExactlyOnceWith(512);
    });
  });
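The final case also fixes an ordering detail: a forced re-run re-applies the model's dimension size (512 here) before any jobs are queued, so a previously failed resize cannot leave the column stale. A small sketch of that ordering, with stand-in names:

```ts
async function queueEncodeClip(
  force: boolean,
  modelDimSize: number,
  db: { setDimensionSize(dim: number): Promise<void> },
  queueAll: () => Promise<void>,
): Promise<void> {
  if (force) {
    // re-assert the column size first, in case an earlier attempt died midway
    await db.setDimensionSize(modelDimSize);
  }
  await queueAll(); // only then enqueue SMART_SEARCH jobs
}
```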
@@ -38,7 +38,7 @@ export class SmartInfoService extends BaseService {
    await this.databaseRepository.withLock(DatabaseLock.CLIPDimSize, async () => {
      const { dimSize } = getCLIPModelInfo(newConfig.machineLearning.clip.modelName);
-      const dbDimSize = await this.searchRepository.getDimensionSize();
+      const dbDimSize = await this.databaseRepository.getDimensionSize('smart_search');
      this.logger.verbose(`Current database CLIP dimension size is ${dbDimSize}`);

      const modelChange =
@@ -53,10 +53,10 @@ export class SmartInfoService extends BaseService {
          `Dimension size of model ${newConfig.machineLearning.clip.modelName} is ${dimSize}, but database expects ${dbDimSize}.`,
        );
        this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`);
-        await this.searchRepository.setDimensionSize(dimSize);
+        await this.databaseRepository.setDimensionSize(dimSize);
        this.logger.log(`Successfully updated database CLIP dimension size from ${dbDimSize} to ${dimSize}.`);
      } else {
-        await this.searchRepository.deleteAllSearchEmbeddings();
+        await this.databaseRepository.deleteAllSearchEmbeddings();
      }

      // TODO: A job to reindex all assets should be scheduled, though user
@@ -74,7 +74,7 @@ export class SmartInfoService extends BaseService {
    if (force) {
      const { dimSize } = getCLIPModelInfo(machineLearning.clip.modelName);
      // in addition to deleting embeddings, update the dimension size in case it failed earlier
-      await this.searchRepository.setDimensionSize(dimSize);
+      await this.databaseRepository.setDimensionSize(dimSize);
    }

    let queue: JobItem[] = [];
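All of the mutations above run inside `withLock(DatabaseLock.CLIPDimSize, ...)`, which keeps concurrent workers from racing on the same schema change. A sketch of that pattern over Postgres advisory locks using the `pg` client; the lock id and helper name are illustrative, not Immich's actual `DatabaseRepository` implementation:

```ts
import { Client } from 'pg';

// Serialize a critical section across processes with a session-level advisory
// lock; Postgres releases the lock on unlock or when the session ends.
async function withAdvisoryLock<T>(client: Client, lockId: number, fn: () => Promise<T>): Promise<T> {
  await client.query('SELECT pg_advisory_lock($1)', [lockId]);
  try {
    return await fn();
  } finally {
    await client.query('SELECT pg_advisory_unlock($1)', [lockId]);
  }
}
```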
@@ -1,7 +1,7 @@
import { SystemConfig } from 'src/config';
+import { VECTOR_EXTENSIONS } from 'src/constants';
import {
  AssetType,
  DatabaseExtension,
  DatabaseSslMode,
  ExifOrientation,
  ImageFormat,

@@ -363,7 +363,7 @@ export type JobItem =
  // Version check
  | { name: JobName.VERSION_CHECK; data: IBaseJob };

-export type VectorExtension = DatabaseExtension.VECTOR | DatabaseExtension.VECTORS;
+export type VectorExtension = (typeof VECTOR_EXTENSIONS)[number];
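The new `VectorExtension` definition derives the union from the `VECTOR_EXTENSIONS` tuple, so adding an extension to the constant automatically widens the type and the two can never drift apart. A self-contained illustration of the TypeScript mechanism (string literals stand in for the `DatabaseExtension` enum members used by the real constant):

```ts
const VECTOR_EXTENSIONS = ['vchord', 'vectors', 'vector'] as const;

// Indexing a readonly tuple type by `number` yields the union of its element types.
type VectorExtension = (typeof VECTOR_EXTENSIONS)[number]; // 'vchord' | 'vectors' | 'vector'

const ok: VectorExtension = 'vchord'; // compiles
// const bad: VectorExtension = 'hnsw'; // rejected: not in the tuple
```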

export type DatabaseConnectionURL = {
  connectionType: 'url';
@@ -384,14 +384,28 @@ export function searchAssetBuilder(kysely: Kysely<DB>, options: AssetSearchBuild
    .$if(!options.withDeleted, (qb) => qb.where('assets.deletedAt', 'is', null));
}

-type VectorIndexOptions = { vectorExtension: VectorExtension; table: string; indexName: string };
+export type ReindexVectorIndexOptions = { indexName: string; lists?: number };

-export function vectorIndexQuery({ vectorExtension, table, indexName }: VectorIndexOptions): string {
+type VectorIndexQueryOptions = { table: string; vectorExtension: VectorExtension } & ReindexVectorIndexOptions;
+
+export function vectorIndexQuery({ vectorExtension, table, indexName, lists }: VectorIndexQueryOptions): string {
  switch (vectorExtension) {
+    case DatabaseExtension.VECTORCHORD: {
+      return `
+        CREATE INDEX IF NOT EXISTS ${indexName} ON ${table} USING vchordrq (embedding vector_cosine_ops) WITH (options = $$
+        residual_quantization = false
+        [build.internal]
+        lists = [${lists ?? 1}]
+        spherical_centroids = true
+        build_threads = 4
+        sampling_factor = 1024
+        $$)`;
+    }
    case DatabaseExtension.VECTORS: {
      return `
        CREATE INDEX IF NOT EXISTS ${indexName} ON ${table}
        USING vectors (embedding vector_cos_ops) WITH (options = $$
        optimizing.optimizing_threads = 4
        [indexing.hnsw]
        m = 16
        ef_construction = 300
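A hedged usage sketch for the VectorChord branch above. The `lists` sizing heuristic (about the square root of the row count, padded with some slack) is an assumption for illustration only, as is the `clip_index` name; neither is necessarily what Immich computes:

```ts
// Assume roughly one million CLIP embeddings in smart_search.
const rowCount = 1_000_000;
const lists = Math.max(1, Math.round(Math.sqrt(rowCount) * 1.2));

const sql = vectorIndexQuery({
  vectorExtension: DatabaseExtension.VECTORCHORD,
  table: 'smart_search',
  indexName: 'clip_index',
  lists,
});
// -> CREATE INDEX IF NOT EXISTS clip_index ON smart_search USING vchordrq (...)
```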
@@ -170,7 +170,7 @@ export const getRepository = <K extends keyof RepositoriesTypes>(key: K, db: Kys
    }

    case 'search': {
-      return new SearchRepository(db, new ConfigRepository());
+      return new SearchRepository(db);
    }

    case 'session': {
@@ -7,7 +7,7 @@ import { getKyselyConfig } from 'src/utils/database';
import { GenericContainer, Wait } from 'testcontainers';

const globalSetup = async () => {
-  const postgresContainer = await new GenericContainer('tensorchord/pgvecto-rs:pg14-v0.2.0')
+  const postgresContainer = await new GenericContainer('ghcr.io/immich-app/postgres:14')
    .withExposedPorts(5432)
    .withEnvironment({
      POSTGRES_PASSWORD: 'postgres',
@@ -17,9 +17,7 @@ const globalSetup = async () => {
    .withCommand([
      'postgres',
      '-c',
-      'shared_preload_libraries=vectors.so',
-      '-c',
-      'search_path="$$user", public, vectors',
+      'shared_preload_libraries=vchord.so',
      '-c',
      'max_wal_size=2GB',
      '-c',
@@ -30,6 +28,8 @@ const globalSetup = async () => {
      'full_page_writes=off',
      '-c',
      'synchronous_commit=off',
+      '-c',
+      'config_file=/var/lib/postgresql/data/postgresql.conf',
    ])
    .withWaitStrategy(Wait.forAll([Wait.forLogMessage('database system is ready to accept connections', 2)]))
    .start();
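For context, a sketch of how a test could talk to the started container. `getHost` and `getMappedPort` are real testcontainers APIs; the credentials mirror the `POSTGRES_PASSWORD` set above, while the database name and the `vchord` extension statement are assumptions about the `ghcr.io/immich-app/postgres` image:

```ts
import { Client } from 'pg';

const client = new Client({
  host: postgresContainer.getHost(),
  port: postgresContainer.getMappedPort(5432),
  user: 'postgres',
  password: 'postgres',
  database: 'postgres',
});

await client.connect();
// CASCADE pulls in pgvector, which VectorChord builds on.
await client.query('CREATE EXTENSION IF NOT EXISTS vchord CASCADE');
await client.end();
```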
@@ -6,13 +6,17 @@ export const newDatabaseRepositoryMock = (): Mocked<RepositoryInterface<Database
  return {
    shutdown: vitest.fn(),
    getExtensionVersion: vitest.fn(),
+    getVectorExtension: vitest.fn(),
    getExtensionVersionRange: vitest.fn(),
    getPostgresVersion: vitest.fn().mockResolvedValue('14.10 (Debian 14.10-1.pgdg120+1)'),
    getPostgresVersionRange: vitest.fn().mockReturnValue('>=14.0.0'),
    createExtension: vitest.fn().mockResolvedValue(void 0),
    updateVectorExtension: vitest.fn(),
-    reindex: vitest.fn(),
-    shouldReindex: vitest.fn(),
+    reindexVectorsIfNeeded: vitest.fn(),
+    getDimensionSize: vitest.fn(),
+    setDimensionSize: vitest.fn(),
+    deleteAllSearchEmbeddings: vitest.fn(),
    prewarm: vitest.fn(),
    runMigrations: vitest.fn(),
    withLock: vitest.fn().mockImplementation((_, function_: <R>() => Promise<R>) => function_()),
    tryLock: vitest.fn(),
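Note the `withLock` mock above: it drops the lock name and invokes the callback immediately, so unit tests exercise the locked body without any real locking. A small self-contained demonstration of the same pattern:

```ts
import { expect, it, vitest } from 'vitest';

it('pass-through lock mock still runs the critical section', async () => {
  const withLock = vitest.fn().mockImplementation((_lock: unknown, fn: () => Promise<void>) => fn());
  const work = vitest.fn().mockResolvedValue(undefined);

  await withLock('CLIPDimSize', work);

  expect(work).toHaveBeenCalledTimes(1); // the callback ran despite no real lock
});
```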