"""Source code for now.executor.indexer.elastic.es_query_building: helpers that build Elasticsearch script-score queries."""

from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

from docarray import Document, DocumentArray

from now.utils.docarray.helpers import get_chunk_by_field_name

# Translates the ``metric`` names accepted by ``build_es_queries`` into the
# Painless vector function name that is written into the script-score source.
metrics_mapping = dict(
    cosine='cosineSimilarity',
    l2_norm='l2norm',
)


def generate_score_calculation(
    docs_map: Dict[str, 'DocumentArray'],
    encoder_to_fields: Dict[str, Union[List[str], str]],
) -> List[List]:
    """Generate the score calculation from the encoded query documents.

    For each executor in ``docs_map``, only the first document is inspected.
    Every field in its ``multi_modal_schema`` whose chunk carries an embedding
    (on the chunk itself or on its sub-chunks) is paired with every document
    field that ``encoder_to_fields`` lists for that executor. Each pair gets a
    linear weight of 1.

    :param docs_map: dictionary mapping encoder (executor) name to
        DocumentArray. Empty DocumentArrays are skipped.
    :param encoder_to_fields: dictionary mapping encoder name to the document
        field(s) it encoded. A single field may be given as a plain string.
    :return: a list of score calculations, each a list
        ``[query_field, document_field, encoder, linear_weight]``, for example
        ``[['query_text', 'title', 'clip', 1]]``.
    :raises KeyError: if an executor in ``docs_map`` (with documents) has no
        entry in ``encoder_to_fields``.
    """
    score_calculation = []
    for executor_name, da in docs_map.items():
        if len(da) == 0:
            # No document to read the schema from; da[0] would raise IndexError.
            continue
        first_doc = da[0]
        field_names = first_doc._metadata['multi_modal_schema'].keys()
        try:
            document_fields = encoder_to_fields[executor_name]
        except KeyError as e:
            raise KeyError(
                f'Documents are not encoded with same encoder as query. executor_name: {executor_name}, encoder_to_fields: {encoder_to_fields}'
            ) from e
        if isinstance(document_fields, str):
            # A bare field name would otherwise be iterated character by character.
            document_fields = [document_fields]
        for field_name in field_names:
            chunk = get_chunk_by_field_name(first_doc, field_name)
            if chunk.chunks.embeddings is None and chunk.embedding is None:
                # This field carries no embedding to compare against; skip it.
                continue
            score_calculation.extend(
                [field_name, document_field, executor_name, 1]
                for document_field in document_fields
            )
    return score_calculation
def build_es_queries(
    docs_map,
    get_score_breakdown: bool,
    score_calculation: List[Tuple],
    metric: Optional[str] = 'cosine',
    filter: Optional[dict] = None,
    query_to_curated_ids: Optional[Dict[str, list]] = None,
) -> List[Tuple['Document', Dict]]:
    """Build the script-score queries used in Elasticsearch.

    Every unique query document (by ``doc.id``) across all executors in
    ``docs_map`` yields one query. Its query-field embeddings are passed as
    script parameters to a ``script_score`` query, together with the names
    of the index fields to compare them against. The inner query is the bool
    query from ``get_default_query``. If ``get_pinned_query`` returns curated
    ids for the document, the script-score query becomes the ``organic``
    part of a ``pinned`` query.

    The first document object seen for each id is returned, and its
    ``tags['embeddings']`` is reset to a dict. This mutates the input
    document. When ``get_score_breakdown`` is true, that dict is filled with
    the query embeddings, keyed by ``'<query_field>-<matching_method>'``.

    :param docs_map: dictionary mapping encoder (executor) name to
        DocumentArray.
    :param get_score_breakdown: whether to return the score breakdown for
        matches. For this function, it determines whether the query
        embeddings are stored in the returned documents' tags.
    :param score_calculation: list of (query_field, document_field,
        matching_method, linear_weight) entries that define how the score is
        calculated. ``matching_method`` is the name of the encoder or
        ``bm25``.
    :param metric: key of ``metrics_mapping`` selecting the vector function.
        ``None`` falls back to ``'cosine'``.
    :param filter: dictionary of filters to apply to the search. ``None``
        means no filter.
    :param query_to_curated_ids: dictionary mapping query text to a list of
        curated ids. ``None`` means no curation.
    :return: a list of ``(query_document, es_query)`` tuples, one per unique
        document id, in first-seen order.
    :raises KeyError: if a vector score is required and ``metric`` is not a
        key of ``metrics_mapping``.
    """
    # None defaults avoid sharing one mutable dict across calls.
    if filter is None:
        filter = {}
    if query_to_curated_ids is None:
        query_to_curated_ids = {}
    if metric is None:
        metric = 'cosine'

    # The base of the script source is the same for every document. When a
    # bm25 matching is requested, the bool query's relevance _score is
    # squashed into [1.0, 2.0) (for non-negative _score) and used as the base.
    uses_bm25 = any(
        matching_method == 'bm25'
        for (_, _, matching_method, _) in score_calculation
    )
    base_source = '1.0 + _score / (_score + 10.0)' if uses_bm25 else '1.0'

    queries = {}
    pinned_queries = {}
    docs = {}
    sources = {}
    script_params = defaultdict(dict)
    for executor_name, da in docs_map.items():
        for doc in da:
            if doc.id not in docs:
                # First time this id is seen: keep this document object and
                # build its text/filter part of the query once.
                docs[doc.id] = doc
                docs[doc.id].tags['embeddings'] = {}
                queries[doc.id] = get_default_query(doc, score_calculation, filter)
                pinned_queries[doc.id] = get_pinned_query(doc, query_to_curated_ids)
                sources[doc.id] = base_source
            for (
                query_field,
                document_field,
                matching_method,
                linear_weight,
            ) in get_scores(executor_name, score_calculation):
                field_doc = get_chunk_by_field_name(doc, query_field)
                if get_score_breakdown:
                    docs[doc.id].tags['embeddings'][
                        f'{query_field}-{matching_method}'
                    ] = field_doc.embedding
                param_name = f'query_{query_field}_{executor_name}'
                document_string = f'{document_field}-{matching_method}'
                # NOTE(review): Elasticsearch's ``l2norm`` returns a distance,
                # so adding it rewards farther vectors when metric='l2_norm'.
                # Confirm whether ``1 / (1 + l2norm(...))`` was intended.
                sources[doc.id] += (
                    f" + {float(linear_weight)}*{metrics_mapping[metric]}"
                    f"(params.{param_name}, '{document_string}.embedding')"
                )
                script_params[doc.id][param_name] = field_doc.embedding

    es_queries = []
    for doc_id, query in queries.items():
        script_score = {
            'script_score': {
                'query': {
                    'bool': query['bool'],
                },
                'script': {
                    'source': sources[doc_id],
                    'params': script_params[doc_id],
                },
            },
        }
        pinned = pinned_queries[doc_id]
        if pinned:
            # Build a new dict rather than mutating the pinned query in place.
            query_json = {'pinned': {**pinned['pinned'], 'organic': script_score}}
        else:
            query_json = script_score
        es_queries.append((docs[doc_id], query_json))
    return es_queries
def get_default_query(
    doc: 'Document',
    score_calculation: List[Tuple],
    filter: Optional[Dict] = None,
) -> Dict:
    """Build the bool query that the script-score query wraps.

    The query always has a ``match_all`` ``should`` clause, so every document
    that passes ``filter`` matches. Each ``bm25`` entry of
    ``score_calculation`` adds a ``multi_match`` clause on the query field's
    text against the index field, boosted by the entry's linear weight.

    :param doc: query document. Only read for ``bm25`` entries.
    :param score_calculation: (query_field, index_field, matching_method,
        linear_weight) entries. Only ``bm25`` entries are used here.
    :param filter: optional filter mapping, converted with
        ``process_filter``. ``None`` or empty means no filter clause.
    :return: a dict of the form ``{'bool': {'should': [...], 'filter': [...]}}``,
        where the ``filter`` key is present only when a filter is given.
    """
    should = [{'match_all': {}}]
    for query_field, index_field, matching_method, linear_weight in score_calculation:
        if matching_method != 'bm25':
            continue
        text = get_chunk_by_field_name(doc, query_field).text
        should.append(
            {
                'multi_match': {
                    'query': text,
                    'fields': [f"{index_field}^{linear_weight}"],
                }
            }
        )
    query = {'bool': {'should': should}}
    if filter:
        query['bool']['filter'] = process_filter(filter)
    return query
def get_pinned_query(
    doc: 'Document', query_to_curated_ids: Optional[Dict[str, list]] = None
) -> Dict:
    """Return a pinned-query fragment for curated results, or ``{}``.

    If ``doc`` has a truthy ``query_text`` attribute and that attribute's
    ``.text`` is a key of ``query_to_curated_ids``, the curated ids are
    returned as ``{'pinned': {'ids': ids}}``.

    :param doc: query document.
    :param query_to_curated_ids: dictionary mapping query text to a list of
        curated ids. ``None`` means no curation.
    :return: the pinned-query dict, or an empty dict when nothing is curated.
    """
    if query_to_curated_ids is None:
        query_to_curated_ids = {}
    query_chunk = getattr(doc, 'query_text', None)
    if not query_chunk:
        return {}
    query_text = query_chunk.text
    if query_text not in query_to_curated_ids:
        return {}
    return {'pinned': {'ids': query_to_curated_ids[query_text]}}
def process_filter(
    filter: Dict[str, Union[List[str], Dict[str, float]]]
) -> List[Dict[str, Any]]:
    """Translate a filter mapping into a list of Elasticsearch filter clauses.

    In each key, the first ``__`` is replaced by ``.`` to form the field
    path. A list value becomes a ``terms`` clause (categorical filter). A
    dict value becomes a ``range`` clause (numeric filter with operators).

    :param filter: mapping of field name to list of terms or dict of range
        operators.
    :return: one clause per entry, in the mapping's iteration order.
    :raises ValueError: if a value is neither a list nor a dict.
    """

    def _to_clause(field_path, spec):
        # Dispatch on the value's container type to pick the clause kind.
        if isinstance(spec, list):
            return {'terms': {field_path: spec}}
        if isinstance(spec, dict):
            return {'range': {field_path: spec}}
        raise ValueError(
            f'Filter {field_path}: {spec} is not a list of terms or a dictionary of ranges'
        )

    return [
        _to_clause(key.replace('__', '.', 1), value) for key, value in filter.items()
    ]
def get_scores(encoder, score_calculation):
    """Yield the score-calculation entries whose matching method is ``encoder``.

    :param encoder: encoder (executor) name to select.
    :param score_calculation: iterable of 4-item entries
        ``(query_field, document_field, matching_method, linear_weight)``.
    :return: generator of ``(query_field, document_field, matching_method,
        linear_weight)`` tuples, in input order. The input is only consumed
        once iteration starts.
    """
    yield from (
        (query_field, document_field, method, weight)
        for query_field, document_field, method, weight in score_calculation
        if encoder == method
    )