Source: utils/embeddings.js

/**
 * AI Embeddings Service
 * Supports multiple backends: OpenAI, Ollama, or self-hosted TEI
 */

import crypto from 'crypto';
import { query, UUID2hex, HEX2uuid } from '@commtool/sql-query';
import { errorLogger } from './requestLogger.js';
import { sanitizeTextForEmbedding } from './santitizePII.js';

// Model configs, keyed by the EMBEDDING_BACKEND value that selects them.
// NOTE(review): maxTokens is not read anywhere in this file — inputs are not truncated to it.
const MODELS = {
  openai: {
    name: 'text-embedding-3-small',
    dimension: 768, // Reduced from 1536 - OpenAI v3 models support dimension reduction with minimal quality loss
    maxTokens: 8191
  },
  ollama: {
    name: 'nomic-embed-text',
    dimension: 768,
    maxTokens: 8192
  },
  tei: {
    name: 'intfloat/multilingual-e5-large',
    dimension: 1024,
    maxTokens: 512
  }
};


// Configuration getters (lazy evaluation for Vault-loaded env vars)
const getEmbeddingBackend = () => process.env.EMBEDDING_BACKEND || 'openai';
const getOpenAIKey = () => process.env.OPENAI_API_KEY;
const getEmbeddingServerURL = () => process.env.EMBEDDING_SERVER_URL || 'http://localhost:8088';

/**
 * Resolve the embedding vector dimension.
 *
 * EMBEDDING_DIMENSION overrides the backend's default when it parses to a
 * positive integer. Anything else (unset, empty, "abc", "0", negative) falls
 * back to the model default. Without this fallback a NaN or 0 dimension would
 * be sent to the API and stored in the database.
 * Only the OpenAI request forwards this value (as `dimensions`); the Ollama and
 * TEI requests in this file do not send it.
 *
 * @returns {number} - Vector dimension for the active backend
 */
const getEmbeddingDimension = () => {
  const override = Number.parseInt(process.env.EMBEDDING_DIMENSION, 10);
  if (Number.isInteger(override) && override > 0) {
    return override;
  }
  return MODELS[getEmbeddingBackend()].dimension;
};

/**
 * Model config for the active backend, with the resolved (possibly overridden) dimension.
 * @returns {{name: string, dimension: number, maxTokens: number}}
 */
const getCurrentModel = () => ({
  ...MODELS[getEmbeddingBackend()],
  dimension: getEmbeddingDimension() // Allow override via env var
});

/**
 * Generate embeddings using the OpenAI embeddings API.
 *
 * All texts are sent in a single request. The request asks for
 * `currentModel.dimension` components per vector via the `dimensions` parameter.
 * Results are re-ordered by the `index` field attached to each returned item,
 * so output[i] always belongs to texts[i].
 *
 * @param {string[]} texts - Texts to embed
 * @returns {Promise<Float32Array[]>} - One vector per input text, in input order
 * @throws {Error} If OPENAI_API_KEY is unset or the API answers with a non-2xx status
 */
async function generateEmbeddingsOpenAI(texts) {
  const apiKey = getOpenAIKey();
  if (!apiKey) {
    throw new Error('OPENAI_API_KEY environment variable not set');
  }

  // Nothing to embed: skip the network round-trip (the API does not accept an empty input list).
  if (texts.length === 0) {
    return [];
  }

  const currentModel = getCurrentModel();
  const response = await fetch('https://api.openai.com/v1/embeddings', {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      'Authorization': `Bearer ${apiKey}`
    },
    body: JSON.stringify({
      input: texts,
      model: currentModel.name,
      dimensions: currentModel.dimension
    })
  });

  if (!response.ok) {
    const error = await response.text();
    throw new Error(`OpenAI API error: ${response.status} ${error}`);
  }

  const data = await response.json();
  // Order by the per-item `index` instead of trusting array order; sort a copy.
  return [...data.data]
    .sort((a, b) => a.index - b.index)
    .map(item => new Float32Array(item.embedding));
}

/**
 * Generate embeddings through an Ollama server.
 *
 * Each request body carries a single `prompt`, so texts are posted one after
 * another, in input order. The first non-2xx response aborts the whole batch
 * with an error.
 *
 * @param {string[]} texts - Texts to embed
 * @returns {Promise<Float32Array[]>} - One vector per text, same order as input
 */
async function generateEmbeddingsOllama(texts) {
  const endpoint = `${getEmbeddingServerURL()}/api/embeddings`;
  const { name: modelName } = getCurrentModel();
  const vectors = [];

  for (let i = 0; i < texts.length; i += 1) {
    const res = await fetch(endpoint, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ model: modelName, prompt: texts[i] })
    });

    if (!res.ok) {
      throw new Error(`Ollama error: ${res.status}`);
    }

    const body = await res.json();
    vectors.push(new Float32Array(body.embedding));
  }

  return vectors;
}

/**
 * Generate embeddings via a HuggingFace Text Embeddings Inference (TEI) server.
 *
 * All texts go out in one POST to `/embed`. The JSON response is treated as
 * an array of number arrays, one per input, and each is wrapped in a
 * Float32Array.
 *
 * @param {string[]} texts - Texts to embed
 * @returns {Promise<Float32Array[]>} - Embedding vectors as returned by the server
 */
async function generateEmbeddingsTEI(texts) {
  const endpoint = `${getEmbeddingServerURL()}/embed`;
  const request = {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ inputs: texts })
  };

  const res = await fetch(endpoint, request);
  if (!res.ok) {
    throw new Error(`TEI server error: ${res.status}`);
  }

  const vectors = await res.json();
  return vectors.map((values) => new Float32Array(values));
}

/**
 * Embed one or more texts with the backend selected by EMBEDDING_BACKEND.
 *
 * A single string is treated as a one-element batch. Unless `sanitize` is
 * false, every text first goes through sanitizeTextForEmbedding. Backend
 * failures, including an unknown backend name, are reported to errorLogger
 * and then rethrown.
 *
 * @param {string|string[]} texts - Text(s) to embed
 * @param {Object} options - Generation options
 * @param {boolean} [options.sanitize=true] - Sanitize sensitive PII before embedding
 * @param {Object} [options.sanitizeOptions] - Options passed to sanitizeTextForEmbedding
 * @returns {Promise<Float32Array[]>} - Array of embedding vectors
 */
export async function generateEmbeddings(texts, options = {}) {
  const { sanitize = true, sanitizeOptions = {} } = options;

  const batch = Array.isArray(texts) ? texts : [texts];
  const prepared = sanitize
    ? batch.map((text) => sanitizeTextForEmbedding(text, sanitizeOptions))
    : batch;

  const backendName = getEmbeddingBackend();

  try {
    // A Map (not a plain object) so names like "toString" never resolve to a handler.
    const embedders = new Map([
      ['openai', generateEmbeddingsOpenAI],
      ['ollama', generateEmbeddingsOllama],
      ['tei', generateEmbeddingsTEI]
    ]);
    const embed = embedders.get(backendName);
    if (embed === undefined) {
      throw new Error(`Unknown embedding backend: ${backendName}`);
    }
    return await embed(prepared);
  } catch (error) {
    errorLogger(error);
    throw error;
  }
}

/**
 * Embed a single text; thin convenience wrapper over generateEmbeddings.
 * @param {string} text - Text to embed
 * @param {Object} options - Same options as generateEmbeddings (sanitize, sanitizeOptions)
 * @returns {Promise<Float32Array>} - The embedding vector for `text`
 */
export async function generateEmbedding(text, options = {}) {
  const [firstVector] = await generateEmbeddings([text], options);
  return firstVector;
}

/**
 * Convert an embedding vector to a Buffer for database storage.
 *
 * The Buffer covers exactly the vector's bytes (its byteOffset/byteLength).
 * This matters for views into a larger ArrayBuffer, e.g. the result of
 * `subarray()` or of bufferToEmbedding() on a sliced Buffer.
 * `Buffer.from(vector.buffer)` alone would serialise the whole backing store.
 * For a Float32Array input the returned Buffer shares memory with the vector
 * (no copy).
 *
 * @param {Float32Array|number[]} embedding - Embedding vector; other array-likes are converted to float32
 * @returns {Buffer} - Raw float32 bytes in platform byte order
 */
export function embeddingToBuffer(embedding) {
  const vector = embedding instanceof Float32Array ? embedding : Float32Array.from(embedding);
  return Buffer.from(vector.buffer, vector.byteOffset, vector.byteLength);
}


/**
 * Convert a Buffer from the database into a Float32Array.
 *
 * When the Buffer's byteOffset is 4-byte aligned the result is a zero-copy
 * view over the same memory. Otherwise (e.g. a Buffer sliced out of a larger
 * allocation at an odd offset) a Float32Array view would throw RangeError, so
 * the bytes are copied into a fresh, aligned ArrayBuffer first. Trailing bytes
 * that do not form a whole float are ignored.
 *
 * @param {Buffer|Uint8Array} buffer - Binary buffer from DB (raw float32 values)
 * @returns {Float32Array} - Embedding vector
 */
export function bufferToEmbedding(buffer) {
  const bytesPerFloat = Float32Array.BYTES_PER_ELEMENT;
  const count = Math.floor(buffer.byteLength / bytesPerFloat);

  if (buffer.byteOffset % bytesPerFloat === 0) {
    return new Float32Array(buffer.buffer, buffer.byteOffset, count);
  }

  // new Uint8Array(typedArray) copies into its own ArrayBuffer starting at offset 0.
  const aligned = new Uint8Array(buffer.subarray(0, count * bytesPerFloat));
  return new Float32Array(aligned.buffer);
}

/**
 * Calculate cosine similarity between two embeddings.
 *
 * @param {Float32Array|number[]} a - First embedding
 * @param {Float32Array|number[]} b - Second embedding
 * @returns {number} - Similarity in [-1, 1]: 1 = same direction, 0 = orthogonal,
 *   -1 = opposite. Returns 0 when either vector has zero length, since such a
 *   vector has no direction and 0/0 would otherwise give NaN.
 * @throws {Error} If the vectors have different lengths
 */
export function cosineSimilarity(a, b) {
  if (a.length !== b.length) {
    throw new Error('Embeddings must have same dimension');
  }

  let dotProduct = 0;
  let normA = 0;
  let normB = 0;

  for (let i = 0; i < a.length; i++) {
    dotProduct += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }

  const denominator = Math.sqrt(normA) * Math.sqrt(normB);
  return denominator === 0 ? 0 : dotProduct / denominator;
}

/**
 * Hash text so callers can detect whether content changed since the last embedding.
 * @param {string} text - Text to hash
 * @returns {Buffer} - Raw 32-byte SHA-256 digest
 */
export function calculateContentHash(text) {
  const hasher = crypto.createHash('sha256');
  hasher.update(text);
  return hasher.digest();
}

/**
 * Check whether the configured embedding backend looks usable.
 *
 * - openai: only checks that OPENAI_API_KEY is set; no request is sent.
 * - ollama: GET {EMBEDDING_SERVER_URL}/api/tags must answer with a 2xx status.
 * - any other backend (TEI): GET {EMBEDDING_SERVER_URL}/health must answer 2xx.
 *
 * The HTTP probe is aborted after `timeoutMs`, so a server that accepts the
 * connection but never responds is reported unhealthy instead of blocking the
 * caller indefinitely. Errors, including timeouts, are passed to errorLogger
 * and reported as `false`.
 *
 * @param {Object} [options]
 * @param {number} [options.timeoutMs=5000] - Abort the HTTP probe after this many milliseconds
 * @returns {Promise<boolean>}
 */
export async function isEmbeddingServerHealthy({ timeoutMs = 5000 } = {}) {
  try {
    const backend = getEmbeddingBackend();
    if (backend === 'openai') {
      // Configuration check only: confirms an API key is present, not that it is valid.
      return Boolean(getOpenAIKey());
    }

    const probePath = backend === 'ollama' ? '/api/tags' : '/health';
    const response = await fetch(`${getEmbeddingServerURL()}${probePath}`, {
      signal: AbortSignal.timeout(timeoutMs)
    });
    return response.ok;
  } catch (error) {
    errorLogger(error);
    return false;
  }
}

/**
 * Get information about the configured embedding backend.
 *
 * - openai: returns `{ backend: 'OpenAI', model, dimension }` built from local config.
 * - ollama: returns the parsed JSON of GET {EMBEDDING_SERVER_URL}/api/tags.
 * - any other backend (TEI): returns the parsed JSON of GET {EMBEDDING_SERVER_URL}/info.
 *
 * A non-2xx response is treated as a failure, so an error payload is never
 * returned as if it were server info. The request is aborted after `timeoutMs`.
 * Any failure is passed to errorLogger and `null` is returned.
 *
 * @param {Object} [options]
 * @param {number} [options.timeoutMs=5000] - Abort the HTTP request after this many milliseconds
 * @returns {Promise<Object|null>}
 */
export async function getEmbeddingServerInfo({ timeoutMs = 5000 } = {}) {
  try {
    const backend = getEmbeddingBackend();
    const currentModel = getCurrentModel();
    if (backend === 'openai') {
      return {
        backend: 'OpenAI',
        model: currentModel.name,
        dimension: currentModel.dimension
      };
    }

    const infoPath = backend === 'ollama' ? '/api/tags' : '/info';
    const response = await fetch(`${getEmbeddingServerURL()}${infoPath}`, {
      signal: AbortSignal.timeout(timeoutMs)
    });
    if (!response.ok) {
      throw new Error(`Embedding server info request failed: ${response.status}`);
    }
    return await response.json();
  } catch (error) {
    errorLogger(error);
    return null;
  }
}


/**
 * Build the text that represents an entity in the embedding index.
 *
 * Precedence: a string `data` is used as-is, then a custom `textBuilder`, then
 * (when `sanitize` is set and the type is person/extern/guest/job) the
 * `embeddingText` produced by the default export of ./santitizePII.js, and
 * finally `JSON.stringify(data)`.
 */
async function buildEntityEmbeddingText(data, entityType, { sanitize = false, textBuilder = null } = {}) {
    if (typeof data === 'string') {
        return data;
    }
    if (textBuilder) {
        return textBuilder(data);
    }
    if (sanitize && ['person', 'extern', 'guest', 'job'].includes(entityType)) {
        const { default: sanitizePII } = await import('./santitizePII.js');
        return sanitizePII(data).embeddingText;
    }
    return JSON.stringify(data);
}

/**
 * True when a stored AIEmbeddings row matches the current content hash,
 * model name and vector dimension, i.e. regeneration can be skipped.
 */
function isEmbeddingRowCurrent(row, contentHash, model) {
    // A NULL stored hash must trigger regeneration, not a TypeError that would block it forever.
    if (row.ContentHash == null) {
        return false;
    }
    return Buffer.compare(row.ContentHash, contentHash) === 0
        && row.EmbeddingModel === model.name
        // Number(): tolerate numeric columns returned as BigInt/string; NULL becomes 0 and mismatches.
        && Number(row.EmbeddingDimension) === model.dimension;
}

/**
 * Updates or creates the AI embedding row for an entity.
 *
 * Builds the text to embed and hashes it. All work is skipped when the stored
 * AIEmbeddings row already has the same content hash, model name AND vector
 * dimension. Otherwise a new embedding is generated and the row is UPDATEd (if
 * present) or INSERTed.
 *
 * The dimension is part of the freshness check because EMBEDDING_DIMENSION can
 * change the vector size without changing the model name. Comparing only name
 * and hash would keep stale vectors of the old size.
 *
 * The embedding is requested with generateEmbedding's default options, so its
 * own text sanitization (sanitize=true) applies as well. The content hash is
 * taken over the text before that step.
 *
 * Errors are logged via errorLogger and swallowed, so embedding generation
 * never breaks the caller's main flow.
 *
 * @param {Buffer} entityUID - Entity UID as buffer; used as AIEmbeddings.UID (one row per entity)
 * @param {Object|string} data - Entity data object or pre-prepared text
 * @param {string} entityType - Entity type ('person', 'extern', 'group', etc.)
 * @param {Buffer} organizationUID - Organization UID; passed through UUID2hex() and stored on INSERT only
 * @param {*} userUID - Not used by this function; accepted for call-site compatibility
 * @param {Object} [options] - Optional configuration
 * @param {boolean} [options.sanitize=false] - Sanitize PII for person/extern/guest/job data
 * @param {Function} [options.textBuilder] - Custom function to build embedding text from data
 * @returns {Promise<void>}
 */
export async function updateEntityEmbedding(entityUID, data, entityType, organizationUID, userUID, options = {}) {
    try {
        const embeddingText = await buildEntityEmbeddingText(data, entityType, options);
        const contentHash = calculateContentHash(embeddingText);
        const currentModel = getCurrentModel();

        // UID is the entity UID itself - one embedding per entity
        const existing = await query(`
            SELECT UID, ContentHash, EmbeddingModel, EmbeddingDimension
            FROM AIEmbeddings 
            WHERE UID = ?
            LIMIT 1
        `, [entityUID]);

        if (existing.length > 0 && isEmbeddingRowCurrent(existing[0], contentHash, currentModel)) {
            return;
        }

        const embedding = await generateEmbedding(embeddingText);
        const embeddingBuffer = embeddingToBuffer(embedding);

        // Overwrite in place when a row exists (content, model or dimension changed).
        if (existing.length > 0) {
            await query(`
                UPDATE AIEmbeddings 
                SET Embedding = ?,
                    ContentHash = ?,
                    EmbeddingModel = ?,
                    EmbeddingDimension = ?,
                    EntityType = ?,
                    CreatedAt = CURRENT_TIMESTAMP(6)
                WHERE UID = ?
            `, [embeddingBuffer, contentHash, currentModel.name, currentModel.dimension, entityType, entityUID]);
        } else {
            await query(`
                INSERT INTO AIEmbeddings 
                (UID, EntityType, EmbeddingModel, EmbeddingDimension, Embedding, ContentHash, UIDOrganization, CreatedAt)
                VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP(6))
            `, [entityUID, entityType, currentModel.name, currentModel.dimension, embeddingBuffer, contentHash, UUID2hex(organizationUID)]);
        }
    } catch (error) {
        // Log error but don't throw - embedding generation should not break the main flow
        errorLogger(error);
    }
}

// Live view of the embedding configuration: every property is recomputed by
// calling the configuration getters on each access, never cached.
export const EMBEDDING_CONFIG = {
  get backend() {
    return getEmbeddingBackend();
  },
  get model() {
    return getCurrentModel();
  },
  get dimension() {
    const { dimension } = getCurrentModel();
    return dimension;
  },
  get serverUrl() {
    if (getEmbeddingBackend() === 'openai') {
      return 'https://api.openai.com';
    }
    return getEmbeddingServerURL();
  }
};

/**
 * Normalize a distance metric name: anything other than 'euclidean' means cosine.
 * Both the SQL function and the similarity formula derive from this one value,
 * so they can never disagree.
 */
function normalizeDistanceMetric(distanceMetric) {
    return distanceMetric === 'euclidean' ? 'euclidean' : 'cosine';
}

/**
 * Convert a MariaDB vector distance into a similarity score (higher = more similar).
 *   cosine:    1 - distance        (range [-1, 1], 1 = same direction)
 *   euclidean: 1 / (1 + distance)  (range (0, 1], 1 = identical vectors)
 * A NULL/missing distance yields NaN, which fails every threshold comparison.
 */
function distanceToSimilarity(distance, metric) {
    if (distance == null) {
        return NaN;
    }
    const d = Number(distance);
    return metric === 'euclidean' ? 1 / (1 + d) : 1 - d;
}

/**
 * Coerce a caller-supplied limit to a non-negative integer (default 10).
 * Numeric strings such as '20' are accepted, so the value bound to `LIMIT ?` is always a number.
 */
function normalizeSimilarityLimit(limit) {
    const parsed = Number.parseInt(limit, 10);
    return Number.isInteger(parsed) && parsed >= 0 ? parsed : 10;
}

/**
 * Build the nearest-neighbour query.
 * Placeholder order: query vector, then one value per entry of `conditions`, then LIMIT.
 */
function buildSimilaritySQL(metric, conditions) {
    const distanceFunc = metric === 'euclidean'
        ? 'VEC_DISTANCE_EUCLIDEAN'
        : 'VEC_DISTANCE_COSINE';
    const whereClause = conditions.length > 0 ? `WHERE ${conditions.join(' AND ')}` : '';

    // Query using MariaDB vector distance functions; similarity is derived in JS from `distance`.
    return `
            SELECT 
                e.UID,
                e.TUID,
                e.EntityType,
                e.EmbeddingModel,
                e.UIDOrganization,
                e.CreatedAt,
                ${distanceFunc}(?, e.Embedding) AS distance,
                o.Title,
                o.Display,
                o.Type AS ObjectType
            FROM AIEmbeddings e
            LEFT JOIN ObjectBase o ON e.UID = o.UID AND o.ValidUntil = '2038-01-19 03:14:07.999999'
            ${whereClause}
            ORDER BY distance ASC
            LIMIT ?
        `;
}

/** Map one result row to the public result shape (scores rounded to 4 decimals). */
function toSimilarityResult(row, similarity) {
    return {
        uid: HEX2uuid(row.UID),
        tuid: row.TUID,
        entityType: row.EntityType,
        objectType: row.ObjectType,
        title: row.Title,
        display: row.Display,
        similarity: parseFloat(similarity.toFixed(4)),
        distance: parseFloat(Number(row.distance).toFixed(4)),
        model: row.EmbeddingModel,
        organizationUID: HEX2uuid(row.UIDOrganization),
        createdAt: row.CreatedAt
    };
}

/** Keep rows whose similarity reaches `threshold`, preserving the SQL distance ordering. */
function rankSimilarityRows(rows, metric, threshold) {
    const results = [];
    for (const row of rows) {
        const similarity = distanceToSimilarity(row.distance, metric);
        if (similarity >= threshold) {
            results.push(toSimilarityResult(row, similarity));
        }
    }
    return results;
}

/**
 * Search for similar entities using vector similarity
 * 
 * The `limit` nearest rows are fetched first and then filtered by
 * `threshold`, so fewer than `limit` results may be returned.
 *
 * @param {string} searchText - Text to search for
 * @param {Object} options - Search options
 * @param {string} [options.entityType] - Filter by entity type
 * @param {Buffer|string} [options.organizationUID] - Filter by organization (Buffer or hex string)
 * @param {number|string} [options.limit=10] - Maximum number of rows fetched (coerced to an integer)
 * @param {number} [options.threshold=0.7] - Minimum similarity. Cosine similarity is 1 - distance;
 *   euclidean similarity is 1 / (1 + distance), so 1 always means "identical"
 * @param {string} [options.distanceMetric='cosine'] - 'cosine' or 'euclidean'; any other value is treated as cosine
 * @param {boolean} [options.sanitize=true] - Sanitize sensitive PII in search text
 * @param {Object} [options.sanitizeOptions] - Options passed to sanitizeTextForEmbedding
 * @returns {Promise<Array>} - Array of matching entities with similarity scores, nearest first
 */
export async function searchSimilarEntities(searchText, options = {}) {
    const {
        entityType,
        organizationUID,
        limit = 10,
        threshold = 0.7,
        distanceMetric = 'cosine',
        sanitize = true,
        sanitizeOptions = {}
    } = options;

    try {
        const metric = normalizeDistanceMetric(distanceMetric);

        // Generate embedding for search text (with optional sanitization)
        const searchEmbedding = await generateEmbedding(searchText, { sanitize, sanitizeOptions });
        const searchBuffer = embeddingToBuffer(searchEmbedding);

        const conditions = [];
        const params = [searchBuffer];

        if (entityType) {
            conditions.push('e.EntityType = ?');
            params.push(entityType);
        }

        if (organizationUID) {
            conditions.push('e.UIDOrganization = ?');
            // Convert hex string to buffer if needed
            params.push(typeof organizationUID === 'string' ? UUID2hex(organizationUID) : organizationUID);
        }

        params.push(normalizeSimilarityLimit(limit));

        const rows = await query(buildSimilaritySQL(metric, conditions), params);
        return rankSimilarityRows(rows, metric, threshold);
    } catch (error) {
        errorLogger(error);
        throw error;
    }
}

/**
 * Find entities similar to a reference entity
 * 
 * Uses the reference's stored embedding and excludes the reference itself.
 * Unless overridden, results are restricted to the reference's own entity
 * type and organization.
 *
 * @param {Buffer|string} referenceUID - UID of the reference entity (Buffer or hex string)
 * @param {Object} options - Search options
 * @param {string} [options.entityType] - Explicit entity type filter (overrides inheritance)
 * @param {Buffer|string} [options.organizationUID] - Explicit organization filter (overrides inheritance)
 * @param {boolean} [options.inheritType=true] - Filter by the reference's EntityType when entityType is not given
 * @param {boolean} [options.inheritOrganization=true] - Filter by the reference's organization when organizationUID is not given
 * @param {number|string} [options.limit=10] - Maximum number of rows fetched
 * @param {number} [options.threshold=0.7] - Minimum similarity (same scale as searchSimilarEntities)
 * @param {string} [options.distanceMetric='cosine'] - 'cosine' or 'euclidean'; any other value is treated as cosine
 * @returns {Promise<Array>} - Array of similar entities, nearest first
 * @throws {Error} If no embedding is stored for the reference entity
 */
export async function findSimilarEntities(referenceUID, options = {}) {
    try {
        // Convert hex string to buffer if needed
        const uid = typeof referenceUID === 'string' ? UUID2hex(referenceUID) : referenceUID;

        const reference = await query(`
            SELECT Embedding, EntityType, UIDOrganization
            FROM AIEmbeddings
            WHERE UID = ?
            LIMIT 1
        `, [uid]);

        if (reference.length === 0) {
            throw new Error('Reference entity embedding not found');
        }

        const { limit = 10, threshold = 0.7, distanceMetric = 'cosine' } = options;
        const metric = normalizeDistanceMetric(distanceMetric);
        const [{ Embedding: referenceEmbedding, EntityType: referenceType, UIDOrganization: referenceOrg }] = reference;

        // First placeholder is the query vector (in SELECT), then the WHERE values.
        const conditions = ['e.UID != ?'];
        const params = [referenceEmbedding, uid];

        // Inherit filters from reference if not explicitly provided
        if (options.entityType !== undefined) {
            conditions.push('e.EntityType = ?');
            params.push(options.entityType);
        } else if (options.inheritType !== false) {
            conditions.push('e.EntityType = ?');
            params.push(referenceType);
        }

        if (options.organizationUID !== undefined) {
            conditions.push('e.UIDOrganization = ?');
            params.push(typeof options.organizationUID === 'string'
                ? UUID2hex(options.organizationUID)
                : options.organizationUID);
        } else if (options.inheritOrganization !== false) {
            conditions.push('e.UIDOrganization = ?');
            params.push(referenceOrg);
        }

        params.push(normalizeSimilarityLimit(limit));

        const rows = await query(buildSimilaritySQL(metric, conditions), params);
        return rankSimilarityRows(rows, metric, threshold);
    } catch (error) {
        errorLogger(error);
        throw error;
    }
}