# colbert.py

import os
import logging

import torch
import numpy as np

from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint

from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class ColBERT(BaseReranker):
    def __init__(self, name, **kwargs) -> None:
        log.info("ColBERT: Loading model %s", name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        DOCKER = kwargs.get("env") == "docker"
        if DOCKER:
            # Workaround for the Docker container, where the torch extension
            # is not loaded properly and the following error is thrown:
            # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
            # Removing the stale lock file lets the extension rebuild.
            lock_file = (
                "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
            )
            if os.path.exists(lock_file):
                os.remove(lock_file)

        self.ckpt = Checkpoint(
            name,
            colbert_config=ColBERTConfig(model_name=name),
        ).to(self.device)
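
    # The method below implements ColBERT-style MaxSim late interaction: each
    # query token embedding is compared against every document token embedding,
    # the best match per query token is kept, and the per-token maxima are
    # summed into a single score per document. The final softmax across
    # documents is a normalization choice of this module, not part of MaxSim
    # itself.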
    def calculate_similarity_scores(self, query_embeddings, document_embeddings):
        query_embeddings = query_embeddings.to(self.device)
        document_embeddings = document_embeddings.to(self.device)

        # Validate dimensions to ensure compatibility
        if query_embeddings.dim() != 3:
            raise ValueError(
                f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
            )
        if document_embeddings.dim() != 3:
            raise ValueError(
                f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
            )
        if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
            raise ValueError(
                "There should be either a single query or one query per document."
            )
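
        # Illustrative shapes (hypothetical numbers): a single 32-token query
        # and ten 180-token documents at embedding dim 128 give query_embeddings
        # of shape (1, 32, 128) and document_embeddings of shape (10, 180, 128);
        # the leading batch of 1 broadcasts across all ten documents below.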

        # Transpose the query embeddings to align them for batched matmul:
        # (batch, query_len, dim) -> (batch, dim, query_len)
        transposed_query_embeddings = query_embeddings.permute(0, 2, 1)

        # Token-level similarity: (docs, doc_len, dim) x (batch, dim, query_len)
        # -> (docs, doc_len, query_len)
        computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings)

        # Max-pool over document tokens: keep each query token's best match
        maximum_scores = torch.max(computed_scores, dim=1).values

        # Sum the per-query-token maxima to get one relevance score per document
        final_scores = maximum_scores.sum(dim=1)

        # Normalize scores across documents into a probability distribution
        normalized_scores = torch.softmax(final_scores, dim=0)

        return normalized_scores.detach().cpu().numpy().astype(np.float32)
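
    # Worked toy trace (illustrative values): with computed_scores of shape
    # (docs=2, doc_tokens=2, query_tokens=1) equal to [[[0.9], [0.2]],
    # [[0.1], [0.3]]], the max over document tokens leaves [[0.9], [0.3]],
    # the sum over query tokens gives raw scores [0.9, 0.3], and the first
    # document ranks higher after the softmax.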
    def predict(self, sentences):
        # `sentences` is a list of (query, document) pairs; all pairs are
        # assumed to share the same query, so only the first one is encoded
        query = sentences[0][0]
        docs = [i[1] for i in sentences]

        # Embed the documents
        embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]

        # Embed the query (queryFromText expects a list of queries)
        embedded_queries = self.ckpt.queryFromText([query], bsize=32)
        embedded_query = embedded_queries[0]

        # Calculate retrieval scores for the query against all documents
        scores = self.calculate_similarity_scores(
            embedded_query.unsqueeze(0), embedded_docs
        )
        return scores
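

# Minimal usage sketch (hedged): the model id and example pairs below are
# illustrative assumptions, not from the original module. predict() expects a
# list of (query, document) pairs that share a single query.
if __name__ == "__main__":
    reranker = ColBERT("colbert-ir/colbertv2.0")
    pairs = [
        ("what is colbert?", "ColBERT is a late-interaction retrieval model."),
        ("what is colbert?", "BERT is a bidirectional transformer encoder."),
    ]
    print(reranker.predict(pairs))  # softmax-normalized relevance per document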