Source code for clip_client.client

import mimetypes
import os
import time
import warnings
from typing import (
    overload,
    TYPE_CHECKING,
    Optional,
    Union,
    Iterator,
    Generator,
    Iterable,
    Dict,
)
from urllib.parse import urlparse
from functools import partial
from docarray import DocumentArray

if TYPE_CHECKING:
    import numpy as np
    from docarray import Document


class Client:
    """Client that sends texts/images to a CLIP server for encoding and ranking.

    The heavy lifting is delegated to a synchronous and an asynchronous
    ``jina.Client`` created in :meth:`__init__`.
    """

    #: request size used when the caller gives no (or a ``None``) ``batch_size``
    _DEFAULT_BATCH_SIZE = 8

    def __init__(self, server: str):
        """Create a Clip client object that connects to the Clip server.

        Server scheme is in the format of `scheme://netloc:port`, where
            - scheme: one of grpc, websocket, http, grpcs, websockets, https
              (``ws`` / ``wss`` are accepted as short forms of
              ``websocket`` / ``websockets``)
            - netloc: the server ip address or hostname
            - port: the public port of the server

        :param server: the server URI
        :raises ValueError: if ``server`` cannot be parsed or its scheme is
            not one of the supported ones
        """
        try:
            r = urlparse(server)
            _port = r.port  # urlparse raises ValueError on a malformed port
            _scheme = r.scheme
        except (ValueError, TypeError, AttributeError) as ex:
            raise ValueError(f'{server} is not a valid scheme') from ex

        if not _scheme:
            raise ValueError(f'{server} is not a valid scheme')

        _tls = False
        if _scheme in ('grpcs', 'https', 'wss', 'websockets'):
            # the secure variants only differ by the trailing "s"
            _scheme = _scheme[:-1]
            _tls = True

        if _scheme == 'ws':
            _scheme = 'websocket'  # temp fix for the core

        if _scheme not in ('grpc', 'http', 'websocket'):
            raise ValueError(f'{server} is not a valid scheme')

        _kwargs = dict(host=r.hostname, port=_port, protocol=_scheme, tls=_tls)

        # imported lazily and aliased so it does not shadow this class
        from jina import Client as _JinaClient

        self._client = _JinaClient(**_kwargs)
        self._async_client = _JinaClient(**_kwargs, asyncio=True)

    @overload
    def encode(
        self,
        content: Iterable[str],
        *,
        batch_size: Optional[int] = None,
        show_progress: bool = False,
    ) -> 'np.ndarray':
        """Encode images and texts into embeddings where the input is an iterable of raw strings.

        Each image and text must be represented as a string. The following strings are acceptable:
            - local image filepath, will be considered as an image
            - remote image http/https, will be considered as an image
            - a dataURI, will be considered as an image
            - plain text, will be considered as a sentence

        :param content: an iterator of image URIs or sentences, each element is an image or a text sentence as a string.
        :param batch_size: the number of elements in each request when sending ``content``
        :param show_progress: if set, show a progress bar
        :return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
        """
        ...

    @overload
    def encode(
        self,
        content: Union['DocumentArray', Iterable['Document']],
        *,
        batch_size: Optional[int] = None,
        show_progress: bool = False,
    ) -> 'DocumentArray':
        """Encode images and texts into embeddings where the input is an iterable of :class:`docarray.Document`.

        :param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`.
        :param batch_size: the number of elements in each request when sending ``content``
        :param show_progress: if set, show a progress bar
        :return: a :class:`docarray.DocumentArray` holding the Documents returned by the server
        """
        ...

    def encode(self, content, **kwargs):
        """Encode ``content`` with the synchronous client.

        See the overloads for the accepted inputs and keyword arguments
        (``batch_size``, ``show_progress``).

        :raises TypeError: if ``content`` is a bare string instead of an iterable
        :raises ValueError: if the server returned no embeddings
        """
        self._reject_bare_string(content, 'encode')
        self._prepare_streaming(
            not kwargs.get('show_progress'),
            total=self._len_or_none(content),
        )
        results = DocumentArray()
        with self._pbar:
            self._client.post(
                **self._get_post_payload(content, kwargs),
                on_done=partial(self._gather_result, results=results),
            )
        return self._unboxed_result(results)

    @staticmethod
    def _reject_bare_string(content, method_name: str):
        """Raise ``TypeError`` when a single string is passed where an iterable is expected."""
        if isinstance(content, str):
            raise TypeError(
                f'content must be an Iterable of [str, Document], try `.{method_name}(["{content}"])` instead'
            )

    @staticmethod
    def _len_or_none(content):
        """Return ``len(content)`` for sized inputs and ``None`` for generators/iterators."""
        return len(content) if hasattr(content, '__len__') else None

    def _advance_task(self, task, env_key: str, advance: int):
        """Advance a progress task and refresh its byte counter read from ``os.environ[env_key]``."""
        from rich import filesize

        self._pbar.update(
            task,
            advance=advance,
            total_size=str(filesize.decimal(int(os.environ.get(env_key, '0')))),
        )

    def _gather_result(self, response, results: 'DocumentArray'):
        """``on_done`` callback: append the docs of one response to ``results``."""
        if not results:
            # first response received: start the receive task
            self._pbar.start_task(self._r_task)
        r = response.data.docs
        results.extend(r)
        self._advance_task(self._r_task, 'JINA_GRPC_RECV_BYTES', len(r))

    @staticmethod
    def _unboxed_result(results: 'DocumentArray'):
        """Return bare embeddings for client-created Documents, else the DocumentArray itself.

        :raises ValueError: if ``results.embeddings`` is ``None``
        """
        if results.embeddings is None:
            raise ValueError(
                'empty embedding returned from the server. '
                'This often due to a mis-config of the server, '
                'restarting the server or changing the serving port number often solves the problem'
            )
        # Documents built by `_iter_doc` from raw strings carry this tag
        return (
            results.embeddings if ('__created_by_CAS__' in results[0].tags) else results
        )

    def _iter_doc(self, content) -> Generator['Document', None, None]:
        """Lazily turn each element of ``content`` into a Document ready to be sent.

        Strings whose guessed MIME type starts with ``image`` are loaded as
        blobs from their URI; any other string is sent as text. Documents are
        validated/prepared by :meth:`_prepare_single_doc`.

        :raises TypeError: on an element that is neither ``str`` nor ``Document``
        """
        from docarray import Document

        # `profile()` may run before any progress bar has been created
        _has_pbar = hasattr(self, '_pbar')
        if _has_pbar:
            self._pbar.start_task(self._s_task)

        for c in content:
            if isinstance(c, str):
                _mime = mimetypes.guess_type(c)[0]
                if _mime and _mime.startswith('image'):
                    yield Document(
                        tags={'__created_by_CAS__': True}, uri=c
                    ).load_uri_to_blob()
                else:
                    yield Document(tags={'__created_by_CAS__': True}, text=c)
            elif isinstance(c, Document):
                yield self._prepare_single_doc(c)
            else:
                raise TypeError(f'unsupported input type {c!r}')

            if _has_pbar:
                self._advance_task(self._s_task, 'JINA_GRPC_SEND_BYTES', 1)

    def _get_post_payload(self, content, kwargs):
        """Build the keyword arguments for ``post`` on the encoding endpoint ``/``."""
        return dict(
            on='/',
            inputs=self._iter_doc(content),
            # an explicit `batch_size=None` must fall back to the default too
            request_size=kwargs.get('batch_size') or self._DEFAULT_BATCH_SIZE,
            total_docs=self._len_or_none(content),
        )

    def profile(self, content: Optional[str] = '') -> Dict[str, float]:
        """Profiling a single query's roundtrip including network and computation latency.

        Results is summarized in a table.

        :param content: the content to be sent for profiling. By default it sends an empty Document
            that helps you understand the network latency. ``None`` is treated like the empty string.
        :return: the latency report in a dict, all values in milliseconds.
        """
        if content is None:
            content = ''

        st = time.perf_counter()
        r = self._client.post('/', self._iter_doc([content]), return_responses=True)
        ed = (time.perf_counter() - st) * 1000

        route = r[0].routes
        # route[0] is reported as the gateway, route[1] as the CLIP model
        gateway_time = (
            route[0].end_time.ToMilliseconds() - route[0].start_time.ToMilliseconds()
        )
        clip_time = (
            route[1].end_time.ToMilliseconds() - route[1].start_time.ToMilliseconds()
        )
        network_time = ed - gateway_time
        server_network = gateway_time - clip_time

        def _ratio(part, whole):
            # millisecond-resolution timings can legitimately be 0
            return part / whole if whole else 0.0

        from rich.table import Table

        def make_table(_title, _time, _percent):
            table = Table(show_header=False, box=None)
            table.add_row(
                _title, f'[b]{_time:.0f}[/b]ms', f'[dim]{_percent * 100:.0f}%[/dim]'
            )
            return table

        from rich.tree import Tree

        t = Tree(make_table('Roundtrip', ed, 1))
        t.add(
            make_table(
                'Client-server network', network_time, _ratio(network_time, ed)
            )
        )
        t2 = t.add(make_table('Server', gateway_time, _ratio(gateway_time, ed)))
        t2.add(
            make_table(
                'Gateway-CLIP network',
                server_network,
                _ratio(server_network, gateway_time),
            )
        )
        t2.add(make_table('CLIP model', clip_time, _ratio(clip_time, gateway_time)))

        from rich import print

        print(t)

        return {
            'Roundtrip': ed,
            'Client-server network': network_time,
            'Server': gateway_time,
            'Gateway-CLIP network': server_network,
            'CLIP model': clip_time,
        }

    @overload
    async def aencode(
        self,
        content: Iterable[str],
        *,
        batch_size: Optional[int] = None,
        show_progress: bool = False,
    ) -> 'np.ndarray':
        ...

    @overload
    async def aencode(
        self,
        content: Union['DocumentArray', Iterable['Document']],
        *,
        batch_size: Optional[int] = None,
        show_progress: bool = False,
    ) -> 'DocumentArray':
        ...

    async def aencode(self, content, **kwargs):
        """Asynchronous counterpart of :meth:`encode`; same inputs and return value.

        :raises TypeError: if ``content`` is a bare string instead of an iterable
        :raises ValueError: if the server returned no embeddings
        """
        self._reject_bare_string(content, 'aencode')
        self._prepare_streaming(
            not kwargs.get('show_progress'),
            total=self._len_or_none(content),
        )
        results = DocumentArray()
        # enter the progress bar like the sync variant so it is started and stopped
        with self._pbar:
            async for da in self._async_client.post(
                **self._get_post_payload(content, kwargs)
            ):
                if not results:
                    self._pbar.start_task(self._r_task)
                results.extend(da)
                self._advance_task(self._r_task, 'JINA_GRPC_RECV_BYTES', len(da))
        return self._unboxed_result(results)

    def _prepare_streaming(self, disable, total):
        """Create the progress bar with its send/receive tasks and reset the byte counters.

        :param disable: if true the progress bar is created disabled
        :param total: number of elements to process, ``None`` when unknown
        """
        if total is None:
            total = 500
            if not disable:
                # only relevant when a progress bar is actually displayed
                warnings.warn(
                    'the length of the input is unknown, the progressbar would not be accurate.'
                )

        from docarray.array.mixins.io.pbar import get_pbar

        self._pbar = get_pbar(disable)

        os.environ['JINA_GRPC_SEND_BYTES'] = '0'
        os.environ['JINA_GRPC_RECV_BYTES'] = '0'

        self._s_task = self._pbar.add_task(
            ':arrow_up: Send', total=total, total_size=0, start=False
        )
        self._r_task = self._pbar.add_task(
            ':arrow_down: Recv', total=total, total_size=0, start=False
        )

    @staticmethod
    def _prepare_single_doc(d: 'Document'):
        """Return ``d`` ready for sending, loading its URI into a blob when needed.

        :raises TypeError: if ``d`` has neither text, blob, URI nor tensor content
        """
        if d.content_type in ('text', 'blob'):
            return d
        elif not d.blob and d.uri:
            d.load_uri_to_blob()
            return d
        elif d.tensor is not None:
            return d
        else:
            raise TypeError(f'unsupported input type {d!r} {d.content_type}')

    @staticmethod
    def _prepare_rank_doc(d: 'Document', _source: str = 'matches'):
        """Prepare a root Document and every candidate stored under its ``_source`` attribute.

        :raises ValueError: if ``d`` has no candidates under ``_source``
        """
        if not getattr(d, _source):
            raise ValueError(f'`.rank()` requires every doc to have `.{_source}`')
        d = Client._prepare_single_doc(d)
        setattr(
            d,
            _source,
            [Client._prepare_single_doc(c) for c in getattr(d, _source)],
        )
        return d

    def _iter_rank_docs(
        self, content, _source='matches'
    ) -> Generator['Document', None, None]:
        """Lazily prepare each Document of ``content`` for the ``/rank`` endpoint.

        :raises TypeError: on an element that is not a ``Document``
        """
        from docarray import Document

        _has_pbar = hasattr(self, '_pbar')
        if _has_pbar:
            self._pbar.start_task(self._s_task)

        for c in content:
            if isinstance(c, Document):
                yield self._prepare_rank_doc(c, _source)
            else:
                raise TypeError(f'unsupported input type {c!r}')

            if _has_pbar:
                self._advance_task(self._s_task, 'JINA_GRPC_SEND_BYTES', 1)

    def _get_rank_payload(self, content, kwargs):
        """Build the keyword arguments for ``post`` on the ``/rank`` endpoint."""
        return dict(
            on='/rank',
            inputs=self._iter_rank_docs(
                content, _source=kwargs.get('source', 'matches')
            ),
            # an explicit `batch_size=None` must fall back to the default too
            request_size=kwargs.get('batch_size') or self._DEFAULT_BATCH_SIZE,
            total_docs=self._len_or_none(content),
        )

    def rank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
        """Rank image-text matches according to the server CLIP model.

        Given a Document with nested matches, where the root is image/text and the matches is in another modality, i.e.
        text/image; this method ranks the matches according to the CLIP model.
        Each match now has a new score inside ``clip_score`` and matches are sorted descendingly according to this score.
        More details can be found in: https://github.com/openai/CLIP#usage

        :param docs: the input Documents
        :param kwargs: optional ``batch_size``, ``show_progress`` and ``source``
            (name of the attribute holding the candidates, ``matches`` by default)
        :return: the ranked Documents in a DocumentArray.
        """
        self._prepare_streaming(
            not kwargs.get('show_progress'),
            # `docs` may be a generator without a length
            total=self._len_or_none(docs),
        )
        results = DocumentArray()
        with self._pbar:
            self._client.post(
                **self._get_rank_payload(docs, kwargs),
                on_done=partial(self._gather_result, results=results),
            )
        return results

    async def arank(self, docs: Iterable['Document'], **kwargs) -> 'DocumentArray':
        """Asynchronous counterpart of :meth:`rank`; same inputs and return value."""
        self._prepare_streaming(
            not kwargs.get('show_progress'),
            # `docs` may be a generator without a length
            total=self._len_or_none(docs),
        )
        results = DocumentArray()
        # enter the progress bar like the sync variant so it is started and stopped
        with self._pbar:
            async for da in self._async_client.post(
                **self._get_rank_payload(docs, kwargs)
            ):
                if not results:
                    self._pbar.start_task(self._r_task)
                results.extend(da)
                self._advance_task(self._r_task, 'JINA_GRPC_RECV_BYTES', len(da))
        return results