import mimetypes
import os
import time
import warnings
from typing import (
overload,
TYPE_CHECKING,
Optional,
Union,
Iterator,
Generator,
Iterable,
Dict,
)
from urllib.parse import urlparse
from functools import partial
from docarray import DocumentArray
if TYPE_CHECKING:
import numpy as np
from docarray import Document
from jina.clients.base import CallbackFnType
[docs]class Client:
def __init__(self, server: str, credential: dict = {}, **kwargs):
"""Create a Clip client object that connects to the Clip server.
Server scheme is in the format of ``scheme://netloc:port``, where
- scheme: one of grpc, websocket, http, grpcs, websockets, https
- netloc: the server ip address or hostname
- port: the public port of the server
:param server: the server URI
:param credential: the credential for authentication ``{'Authentication': '<token>'}``
"""
try:
r = urlparse(server)
_port = r.port
self._scheme = r.scheme
except:
raise ValueError(f'{server} is not a valid scheme')
_tls = False
if self._scheme in ('grpcs', 'https', 'wss'):
self._scheme = self._scheme[:-1]
_tls = True
if self._scheme == 'ws':
self._scheme = 'websocket' # temp fix for the core
if credential:
warnings.warn(
'Credential is not supported for websocket, please use grpc or http'
)
if self._scheme in ('grpc', 'http', 'websocket'):
_kwargs = dict(host=r.hostname, port=_port, protocol=self._scheme, tls=_tls)
from jina import Client
self._client = Client(**_kwargs)
self._async_client = Client(**_kwargs, asyncio=True)
else:
raise ValueError(f'{server} is not a valid scheme')
self._authorization = credential.get(
'Authorization', os.environ.get('CLIP_AUTH_TOKEN')
)
[docs] def profile(self, content: Optional[str] = '') -> Dict[str, float]:
"""Profiling a single query's roundtrip including network and computation latency. Results is summarized in a table.
:param content: the content to be sent for profiling. By default it sends an empty Document
that helps you understand the network latency.
:return: the latency report in a dict.
"""
st = time.perf_counter()
r = self._client.post(
'/', self._iter_doc([content], DocumentArray()), return_responses=True
)
ed = (time.perf_counter() - st) * 1000
route = r[0].routes
gateway_time = (
route[0].end_time.ToMilliseconds() - route[0].start_time.ToMilliseconds()
)
clip_time = (
route[1].end_time.ToMilliseconds() - route[1].start_time.ToMilliseconds()
)
network_time = ed - gateway_time
server_network = gateway_time - clip_time
from rich.table import Table
def make_table(_title, _time, _percent):
table = Table(show_header=False, box=None)
table.add_row(
_title, f'[b]{_time:.0f}[/b]ms', f'[dim]{_percent * 100:.0f}%[/dim]'
)
return table
from rich.tree import Tree
t = Tree(make_table('Roundtrip', ed, 1))
t.add(make_table('Client-server network', network_time, network_time / ed))
t2 = t.add(make_table('Server', gateway_time, gateway_time / ed))
t2.add(
make_table(
'Gateway-CLIP network', server_network, server_network / gateway_time
)
)
t2.add(make_table('CLIP model', clip_time, clip_time / gateway_time))
from rich import print
print(t)
return {
'Roundtrip': ed,
'Client-server network': network_time,
'Server': gateway_time,
'Gateway-CLIP network': server_network,
'CLIP model': clip_time,
}
def _update_pbar(self, response, func: Optional['CallbackFnType'] = None):
from rich import filesize
r = response.data.docs
if not self._pbar._tasks[self._r_task].started:
self._pbar.start_task(self._r_task)
self._pbar.update(
self._r_task,
advance=len(r),
total_size=str(
filesize.decimal(int(os.environ.get('JINA_GRPC_RECV_BYTES', '0')))
),
)
if func is not None:
func(response)
def _prepare_streaming(self, disable, total):
if total is None:
total = 500
warnings.warn(
'The length of the input is unknown, the progressbar would not be accurate.'
)
elif total > 500:
warnings.warn(
'Please ensure all the inputs are valid, otherwise the request will be aborted.'
)
from docarray.array.mixins.io.pbar import get_pbar
self._pbar = get_pbar(disable)
os.environ['JINA_GRPC_SEND_BYTES'] = '0'
os.environ['JINA_GRPC_RECV_BYTES'] = '0'
self._r_task = self._pbar.add_task(
':arrow_down: Progress', total=total, total_size=0, start=False
)
@staticmethod
def _gather_result(
response, results: 'DocumentArray', attribute: Optional[str] = None
):
r = response.data.docs
if attribute:
results[r[:, 'id']][:, attribute] = r[:, attribute]
def _iter_doc(
self, content, results: Optional['DocumentArray'] = None
) -> Generator['Document', None, None]:
from docarray import Document
for c in content:
if isinstance(c, str):
_mime = mimetypes.guess_type(c)[0]
if _mime and _mime.startswith('image'):
d = Document(
uri=c,
).load_uri_to_blob()
else:
d = Document(text=c)
elif isinstance(c, Document):
if c.content_type in ('text', 'blob'):
d = c
elif not c.blob and c.uri:
c.load_uri_to_blob()
d = c
elif c.tensor is not None:
d = c
else:
raise TypeError(f'unsupported input type {c!r} {c.content_type}')
else:
raise TypeError(f'unsupported input type {c!r}')
if results is not None:
results.append(d)
yield d
def _get_post_payload(
self, content, results: Optional['DocumentArray'] = None, **kwargs
):
payload = dict(
inputs=self._iter_doc(content, results),
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
)
if self._scheme == 'grpc' and self._authorization:
payload.update(metadata=(('authorization', self._authorization),))
elif self._scheme == 'http' and self._authorization:
payload.update(headers={'Authorization': self._authorization})
return payload
@staticmethod
def _unboxed_result(results: Optional['DocumentArray'] = None, unbox: bool = False):
if results is not None:
if results.embeddings is None:
raise ValueError(
'Empty embedding returned from the server. '
'This often due to a mis-config of the server, '
'restarting the server or changing the serving port number often solves the problem'
)
return results.embeddings if unbox else results
@overload
def encode(
self,
content: Iterable[str],
*,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[dict] = None,
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'np.ndarray':
"""Encode images and texts into embeddings where the input is an iterable of raw strings.
Each image and text must be represented as a string. The following strings are acceptable:
- local image filepath, will be considered as an image
- remote image http/https, will be considered as an image
- a dataURI, will be considered as an image
- plain text, will be considered as a sentence
:param content: an iterator of image URIs or sentences, each element is an image or a text sentence as a string.
:param batch_size: the number of elements in each request when sending ``content``
:param show_progress: if set, show a progress bar
:param parameters: the parameters for the encoding, you can specify the model to use when you have multiple models
:param on_done: the callback function executed while streaming, after successful completion of each request.
It takes the response ``DataRequest`` as the only argument
:param on_error: the callback function executed while streaming, after failed completion of each request.
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after completion of each request.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
@overload
def encode(
self,
content: Union['DocumentArray', Iterable['Document']],
*,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[dict] = None,
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
"""Encode images and texts into embeddings where the input is an iterable of :class:`docarray.Document`.
:param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`.
:param batch_size: the number of elements in each request when sending ``content``
:param show_progress: if set, show a progress bar
:param parameters: the parameters for the encoding, you can specify the model to use when you have multiple models
:param on_done: the callback function executed while streaming, after successful completion of each request.
It takes the response ``DataRequest`` as the only argument
:param on_error: the callback function executed while streaming, after failed completion of each request.
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after completion of each request.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
[docs] def encode(self, content, **kwargs):
if isinstance(content, str):
raise TypeError(
f'Content must be an Iterable of [str, Document], try `.encode(["{content}"])` instead'
)
if hasattr(content, '__len__') and len(content) == 0:
return DocumentArray() if isinstance(content, DocumentArray) else []
self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
self._gather_result, results=results, attribute='embedding'
)
with self._pbar:
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
model_name = parameters.pop('model_name', '') if parameters else ''
self._client.post(
on=f'/encode/{model_name}'.rstrip('/'),
**self._get_post_payload(content, results, **kwargs),
on_done=on_done,
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
)
unbox = hasattr(content, '__len__') and isinstance(content[0], str)
return self._unboxed_result(results, unbox)
@overload
async def aencode(
self,
content: Iterator[str],
*,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[dict] = None,
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'np.ndarray':
...
@overload
async def aencode(
self,
content: Union['DocumentArray', Iterable['Document']],
*,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[dict] = None,
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
...
[docs] async def aencode(self, content, **kwargs):
if isinstance(content, str):
raise TypeError(
f'Content must be an Iterable of [str, Document], try `.aencode(["{content}"])` instead'
)
if hasattr(content, '__len__') and len(content) == 0:
return DocumentArray() if isinstance(content, DocumentArray) else []
self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
self._gather_result, results=results, attribute='embedding'
)
with self._pbar:
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
model_name = parameters.get('model_name', '') if parameters else ''
async for _ in self._async_client.post(
on=f'/encode/{model_name}'.rstrip('/'),
**self._get_post_payload(content, results, **kwargs),
on_done=on_done,
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue
unbox = hasattr(content, '__len__') and isinstance(content[0], str)
return self._unboxed_result(results, unbox)
def _iter_rank_docs(
self, content, results: Optional['DocumentArray'] = None, source='matches'
) -> Generator['Document', None, None]:
from docarray import Document
for c in content:
if isinstance(c, Document):
d = self._prepare_rank_doc(c, source)
else:
raise TypeError(f'Unsupported input type {c!r}')
if results is not None:
results.append(d)
yield d
def _get_rank_payload(
self, content, results: Optional['DocumentArray'] = None, **kwargs
):
payload = dict(
inputs=self._iter_rank_docs(
content, results, source=kwargs.get('source', 'matches')
),
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
)
if self._scheme == 'grpc' and self._authorization:
payload.update(metadata=(('authorization', self._authorization),))
elif self._scheme == 'http' and self._authorization:
payload.update(headers={'Authorization': self._authorization})
return payload
@staticmethod
def _prepare_single_doc(d: 'Document'):
if d.content_type in ('text', 'blob'):
return d
elif not d.blob and d.uri:
d.load_uri_to_blob()
return d
elif d.tensor is not None:
return d
else:
raise TypeError(f'Unsupported input type {d!r} {d.content_type}')
@staticmethod
def _prepare_rank_doc(d: 'Document', _source: str = 'matches'):
_get = lambda d: getattr(d, _source)
if not _get(d):
raise ValueError(f'`.rank()` requires every doc to have `.{_source}`')
d = Client._prepare_single_doc(d)
setattr(d, _source, [Client._prepare_single_doc(c) for c in _get(d)])
return d
[docs] def rank(
self, docs: Union['DocumentArray', Iterable['Document']], **kwargs
) -> 'DocumentArray':
"""Rank image-text matches according to the server CLIP model.
Given a Document with nested matches, where the root is image/text and the matches is in another modality, i.e.
text/image; this method ranks the matches according to the CLIP model.
Each match now has a new score inside ``clip_score`` and matches are sorted descendingly according to this score.
More details can be found in: https://github.com/openai/CLIP#usage
:param docs: the input Documents
:return: the ranked Documents in a DocumentArray.
"""
if isinstance(docs, str):
raise TypeError(f'Content must be an Iterable of [Document]')
self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(docs) if hasattr(docs, '__len__') else None,
)
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
with self._pbar:
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
model_name = parameters.get('model_name', '') if parameters else ''
self._client.post(
on=f'/rank/{model_name}'.rstrip('/'),
**self._get_rank_payload(docs, results, **kwargs),
on_done=on_done,
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
)
return results
[docs] async def arank(
self, docs: Union['DocumentArray', Iterable['Document']], **kwargs
) -> 'DocumentArray':
if isinstance(docs, str):
raise TypeError(f'Content must be an Iterable of [Document]')
self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(docs) if hasattr(docs, '__len__') else None,
)
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
with self._pbar:
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
model_name = parameters.get('model_name', '') if parameters else ''
async for _ in self._async_client.post(
on=f'/rank/{model_name}'.rstrip('/'),
**self._get_rank_payload(docs, results, **kwargs),
on_done=on_done,
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue
return results
@overload
def index(
self,
content: Iterable[str],
*,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[Dict] = None,
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
"""Index the images or texts where their embeddings are computed by the server CLIP model.
Each image and text must be represented as a string. The following strings are acceptable:
- local image filepath, will be considered as an image
- remote image http/https, will be considered as an image
- a dataURI, will be considered as an image
- plain text, will be considered as a sentence
:param content: an iterator of image URIs or sentences, each element is an image or a text sentence as a string.
:param batch_size: the number of elements in each request when sending ``content``
:param show_progress: if set, show a progress bar
:param parameters: the parameters for the indexing, you can specify the model to use when you have multiple models
:param on_done: the callback function executed while streaming, after successful completion of each request.
It takes the response ``DataRequest`` as the only argument
:param on_error: the callback function executed while streaming, after an error occurs in each request.
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
@overload
def index(
self,
content: Union['DocumentArray', Iterable['Document']],
*,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[dict] = None,
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
"""Index the images or texts where their embeddings are computed by the server CLIP model.
:param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`.
:param batch_size: the number of elements in each request when sending ``content``
:param show_progress: if set, show a progress bar
:param parameters: the parameters for the indexing, you can specify the model to use when you have multiple models
:param on_done: the callback function executed while streaming, after successful completion of each request.
It takes the response ``DataRequest`` as the only argument
:param on_error: the callback function executed while streaming, after an error occurs in each request.
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
[docs] def index(self, content, **kwargs):
if isinstance(content, str):
raise TypeError(
f'content must be an Iterable of [str, Document], try `.index(["{content}"])` instead'
)
self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
self._gather_result, results=results, attribute='embedding'
)
with self._pbar:
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
self._client.post(
on='/index',
**self._get_post_payload(content, results, **kwargs),
on_done=on_done,
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
)
return results
@overload
async def aindex(
self,
content: Iterator[str],
*,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[Dict] = None,
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
...
@overload
async def aindex(
self,
content: Union['DocumentArray', Iterable['Document']],
*,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[dict] = None,
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
...
[docs] async def aindex(self, content, **kwargs):
if isinstance(content, str):
raise TypeError(
f'content must be an Iterable of [str, Document], try `.aindex(["{content}"])` instead'
)
self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
self._gather_result, results=results, attribute='embedding'
)
with self._pbar:
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
async for _ in self._async_client.post(
on='/index',
**self._get_post_payload(content, results, **kwargs),
on_done=on_done,
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue
return results
@overload
def search(
self,
content: Iterable[str],
*,
limit: int = 10,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[Dict] = None,
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
"""Search for top k results for given query string or ``Document``.
If the input is a string, will use this string as query. If the input is a ``Document``,
will use this ``Document`` as query.
:param content: list of queries.
:param limit: the number of results to return.
:param batch_size: the number of elements in each request when sending ``content``.
:param show_progress: if set, show a progress bar.
:param parameters: parameters passed to search function.
:param on_done: the callback function executed while streaming, after successful completion of each request.
It takes the response ``DataRequest`` as the only argument
:param on_error: the callback function executed while streaming, after an error occurs in each request.
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
"""
...
@overload
def search(
self,
content: Union['DocumentArray', Iterable['Document']],
*,
limit: int = 10,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[dict] = None,
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
"""Search for top k results for given query string or ``Document``.
If the input is a string, will use this string as query. If the input is a ``Document``,
will use this ``Document`` as query.
:param content: list of queries.
:param limit: the number of results to return.
:param batch_size: the number of elements in each request when sending ``content``.
:param show_progress: if set, show a progress bar.
:param parameters: parameters passed to search function.
:param on_done: the callback function executed while streaming, after successful completion of each request.
It takes the response ``DataRequest`` as the only argument
:param on_error: the callback function executed while streaming, after an error occurs in each request.
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
"""
...
[docs] def search(self, content, limit: int = 10, **kwargs) -> 'DocumentArray':
if isinstance(content, str):
raise TypeError(
f'content must be an Iterable of [str, Document], try `.search(["{content}"])` instead'
)
self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
with self._pbar:
parameters = kwargs.pop('parameters', {})
parameters['limit'] = limit
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
self._client.post(
on='/search',
**self._get_post_payload(content, results, **kwargs),
on_done=on_done,
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
)
return results
@overload
async def asearch(
self,
content: Iterator[str],
*,
limit: int = 10,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[Dict] = None,
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
...
@overload
async def asearch(
self,
content: Union['DocumentArray', Iterable['Document']],
*,
limit: int = 10,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[dict] = None,
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
...
[docs] async def asearch(self, content, limit: int = 10, **kwargs):
if isinstance(content, str):
raise TypeError(
f'content must be an Iterable of [str, Document], try `.asearch(["{content}"])` instead'
)
self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
with self._pbar:
parameters = kwargs.pop('parameters', {})
parameters['limit'] = limit
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
async for _ in self._async_client.post(
on='/search',
**self._get_post_payload(content, results, **kwargs),
on_done=on_done,
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue
return results