diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 752c25dd3..7712446e7 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1040,7 +1040,8 @@ def embed( # get pooling information pooling_type = self.pooling_type() - logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE + # All tokens need outputs for embeddings; llama.cpp otherwise logs an "overriding" warning per input. + logits_all = True if self.context_params.embeddings is False: raise RuntimeError(