|
|
from typing import List, Union
# Unique sentinel for `Reducer.alias()`: passing it asks the reducer to reuse
# the name of the field it operates on (minus the leading '@') as its alias.
# It is compared by identity (`is`), so no user string can collide with it.
FIELDNAME = object()
class Limit:
    """
    Paging window (offset and count) that renders itself as a `LIMIT` clause.
    """

    def __init__(self, offset: int = 0, count: int = 0) -> None:
        self.offset, self.count = offset, count

    def build_args(self):
        """
        Return the `LIMIT` arguments as a list of strings.

        An empty list is returned when `count` is falsy (e.g. 0), meaning no
        `LIMIT` clause is emitted at all.
        """
        if not self.count:
            return []
        return ["LIMIT", str(self.offset), str(self.count)]
class Reducer:
    """
    Common base class shared by every reducer.

    The concrete reducers live in the `redisearch.reducers` module.
    """

    # Name emitted after REDUCE; the base class leaves it unset.
    NAME = None

    def __init__(self, *args: List[str]) -> None:
        self._alias = None
        self._field = None
        self._args = args

    def alias(self, alias: str) -> "Reducer":
        """
        Assign the alias under which this reducer's result is exposed.

        ### Parameters

        - **alias**: The alias to use. Passing the special value
          `aggregation.FIELDNAME` aliases the reducer with the name of the
          field it operates on; this only works for reducers that operate on
          a single field value.

        Raises `ValueError` when `FIELDNAME` is requested but the reducer has
        no field recorded.

        Returns the `Reducer` itself so that calls can be chained.
        """
        if alias is not FIELDNAME:
            self._alias = alias
            return self

        if not self._field:
            raise ValueError("Cannot use FIELDNAME alias with no field")

        # The field is stored as "@name"; the alias is the bare name.
        self._alias = self._field[1:]
        return self

    @property
    def args(self) -> List[str]:
        """The positional arguments this reducer was constructed with."""
        return self._args
class SortDirection:
    """
    Marker base class that pairs a field with a sort direction.

    Subclasses supply the direction keyword through `DIRSTRING`.
    """

    # Direction keyword; the base class leaves it unset.
    DIRSTRING = None

    def __init__(self, field: str) -> None:
        # The field this direction applies to.
        self.field = field
class Asc(SortDirection):
    """
    Sort the wrapped field in ascending order.
    """

    # Keyword emitted right after the field name.
    DIRSTRING = "ASC"
class Desc(SortDirection):
    """
    Sort the wrapped field in descending order.
    """

    # Keyword emitted right after the field name.
    DIRSTRING = "DESC"
class AggregateRequest:
    """
    Aggregation request which can be passed to `Client.aggregate`.

    The request is assembled step by step through chained method calls and is
    finally rendered into a flat argument list by `build_args()`.
    """

    def __init__(self, query: str = "*") -> None:
        """
        Create an aggregation request. This request may then be passed to
        `client.aggregate()`.

        In order for the request to be usable, it must contain at least one
        group.

        - **query** Query string for filtering records.

        All member methods (except `build_args()`) return the object itself,
        making them useful for chaining.
        """
        self._query = query
        # Ordered pipeline steps (GROUPBY, APPLY, LIMIT, SORTBY, FILTER, ...).
        self._aggregateplan = []
        self._loadfields = []
        self._loadall = False
        self._max = 0
        self._with_schema = False
        self._verbatim = False
        self._cursor = []
        self._dialect = None
        self._add_scores = False

    @staticmethod
    def _flatten(items) -> list:
        """Expand list/tuple members of `items` one level into a flat list."""
        flat = []
        for item in items:
            if isinstance(item, (list, tuple)):
                flat.extend(item)
            else:
                flat.append(item)
        return flat

    def load(self, *fields: str) -> "AggregateRequest":
        """
        Indicate the fields to be returned in the response. These fields are
        returned in addition to any others implicitly specified.

        ### Parameters

        - **fields**: If fields not specified, all the fields will be loaded.
          Otherwise, fields should be given in the format of `@field`.
        """
        if fields:
            self._loadfields.extend(fields)
        else:
            self._loadall = True
        return self

    def group_by(
        self,
        fields: Union[str, List[str]],
        *reducers: Union["Reducer", List["Reducer"]],
    ) -> "AggregateRequest":
        """
        Specify by which fields to group the aggregation.

        ### Parameters

        - **fields**: Fields to group by. This can either be a single string,
          or a list of strings. In both cases, the field should be specified
          as `@field`.
        - **reducers**: One or more reducers, given either as separate
          arguments or as a list. Reducers may be found in the `aggregation`
          module.
        """
        if isinstance(fields, str):
            fields = [fields]

        ret = ["GROUPBY", str(len(fields)), *fields]
        # `reducers` is always a tuple (varargs); flatten it so that a list of
        # reducers passed as a single argument is handled as well.
        for reducer in self._flatten(reducers):
            ret += ["REDUCE", reducer.NAME, str(len(reducer.args))]
            ret.extend(reducer.args)
            if reducer._alias is not None:
                ret += ["AS", reducer._alias]

        self._aggregateplan.extend(ret)
        return self

    def apply(self, **kwexpr) -> "AggregateRequest":
        """
        Specify one or more projection expressions to add to each result

        ### Parameters

        - **kwexpr**: One or more key-value pairs for a projection. The key is
          the alias for the projection, and the value is the projection
          expression itself, for example `apply(square_root="sqrt(@foo)")`
        """
        # Keyword names are always strings, so every projection gets an alias.
        for alias, expr in kwexpr.items():
            self._aggregateplan.extend(["APPLY", expr, "AS", alias])
        return self

    def limit(self, offset: int, num: int) -> "AggregateRequest":
        """
        Sets the limit for the most recent group or query.

        If no group has been defined yet (via `group_by()`) then this sets
        the limit for the initial pool of results from the query. Otherwise,
        this limits the number of items operated on from the previous group.

        Setting a limit on the initial search results may be useful when
        attempting to execute an aggregation on a sample of a large data set.

        ### Parameters

        - **offset**: Result offset from which to begin paging
        - **num**: Number of results to return. A falsy value (e.g. 0) adds
          no `LIMIT` step at all.

        Example of limiting the initial results:

        ```
        (
            AggregateRequest("@sale_amount:[10000, inf]")
            .limit(0, 10)
            .group_by("@state", r.count())
        )
        ```

        Will only group by the states found in the first 10 results of the
        query `@sale_amount:[10000, inf]`. On the other hand,

        ```
        (
            AggregateRequest("@sale_amount:[10000, inf]")
            .limit(0, 1000)
            .group_by("@state", r.count())
            .limit(0, 10)
        )
        ```

        Will group all the results matching the query, but only return the
        first 10 groups.

        If you only wish to return a *top-N* style query, consider using
        `sort_by()` instead.
        """
        self._aggregateplan.extend(Limit(offset, num).build_args())
        return self

    def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest":
        """
        Indicate how the results should be sorted. This can also be used for
        *top-N* style queries

        ### Parameters

        - **fields**: The fields by which to sort, given either as separate
          arguments or as a single list. If you wish to specify order, you
          can use the `Asc` or `Desc` wrapper classes.
        - **max**: Maximum number of results to return. This can be used
          instead of `LIMIT` and is also faster.

        Example of sorting by `foo` ascending and `bar` descending:

        ```
        sort_by(Asc("@foo"), Desc("@bar"))
        ```

        Return the top 10 customers:

        ```
        (
            AggregateRequest()
            .group_by("@customer", r.sum("@paid").alias(FIELDNAME))
            .sort_by(Desc("@paid"), max=10)
        )
        ```
        """
        fields_args = []
        # `fields` is always a tuple (varargs); flatten it so that a list of
        # fields passed as a single argument does not end up nested.
        for field in self._flatten(fields):
            if isinstance(field, str):
                fields_args.append(field)
            elif isinstance(field, SortDirection):
                fields_args += [field.field, field.DIRSTRING]
            else:
                # Anything else is passed through untouched.
                fields_args.append(field)

        ret = ["SORTBY", str(len(fields_args)), *fields_args]

        max_results = kwargs.get("max", 0)
        if max_results > 0:
            ret += ["MAX", str(max_results)]

        self._aggregateplan.extend(ret)
        return self

    def filter(self, expressions: Union[str, List[str]]) -> "AggregateRequest":
        """
        Specify filter for post-query results using predicates relating to
        values in the result set.

        ### Parameters

        - **expressions**: The filter expression(s). This can either be a
          single string, or a list of strings; one `FILTER` step is added per
          expression.
        """
        if isinstance(expressions, str):
            expressions = [expressions]

        for expression in expressions:
            self._aggregateplan.extend(["FILTER", expression])

        return self

    def with_schema(self) -> "AggregateRequest":
        """
        If set, the `schema` property will contain a list of `[field, type]`
        entries in the result object.
        """
        self._with_schema = True
        return self

    def add_scores(self) -> "AggregateRequest":
        """
        If set, includes the score as an ordinary field of the row.
        """
        self._add_scores = True
        return self

    def verbatim(self) -> "AggregateRequest":
        """
        Add the `VERBATIM` flag to the request.
        """
        self._verbatim = True
        return self

    def cursor(self, count: int = 0, max_idle: float = 0.0) -> "AggregateRequest":
        """
        Request a cursor for the results (`WITHCURSOR`).

        ### Parameters

        - **count**: Emitted as `COUNT` when truthy.
        - **max_idle**: Emitted as `MAXIDLE` when truthy, after being
          multiplied by 1000 (presumably seconds to milliseconds; confirm
          against the server documentation).

        Calling this again replaces the previous cursor settings.
        """
        args = ["WITHCURSOR"]
        if count:
            args += ["COUNT", str(count)]
        if max_idle:
            args += ["MAXIDLE", str(max_idle * 1000)]
        self._cursor = args
        return self

    def build_args(self) -> List[str]:
        """
        Render the request into the flat argument list for the command.

        Order: query string, option flags, cursor settings, `LOAD`,
        `DIALECT`, then the pipeline steps in the order they were added.
        """
        # @foo:bar ...
        ret = [self._query]

        if self._with_schema:
            ret.append("WITHSCHEMA")

        if self._verbatim:
            ret.append("VERBATIM")

        if self._add_scores:
            ret.append("ADDSCORES")

        if self._cursor:
            ret += self._cursor

        if self._loadall:
            ret += ["LOAD", "*"]
        elif self._loadfields:
            ret += ["LOAD", str(len(self._loadfields)), *self._loadfields]

        if self._dialect:
            # The dialect is passed through exactly as it was given.
            ret.extend(["DIALECT", self._dialect])

        ret.extend(self._aggregateplan)

        return ret

    def dialect(self, dialect: int) -> "AggregateRequest":
        """
        Add a dialect field to the aggregate command.

        - **dialect** - dialect version to execute the query under
        """
        self._dialect = dialect
        return self
class Cursor:
    """
    Handle for a cursor, identified by its cursor id.

    `max_idle` and `count` start at 0 and may be set by the caller before
    `build_args()` is invoked.
    """

    def __init__(self, cid: int) -> None:
        self.cid = cid
        self.max_idle = 0
        self.count = 0

    def build_args(self):
        """
        Return the cursor arguments as a list of strings.

        The list always starts with the cursor id; `MAXIDLE` and `COUNT` are
        appended, in that order, only when their values are truthy.
        """
        parts = [str(self.cid)]
        optional = (("MAXIDLE", self.max_idle), ("COUNT", self.count))
        for keyword, value in optional:
            if value:
                parts.extend((keyword, str(value)))
        return parts
class AggregateResult:
    """
    Container for the outcome of an aggregation: the rows, an optional
    cursor, and the schema, stored exactly as supplied by the caller.
    """

    def __init__(self, rows, cursor: Union["Cursor", None], schema) -> None:
        """
        - **rows**: The result rows; must support `len()` for `repr()`.
        - **cursor**: The `Cursor` associated with the result, or a falsy
          value (e.g. `None`) when there is none.
        - **schema**: Schema information for the result, stored as given.
        """
        self.rows = rows
        self.cursor = cursor
        self.schema = schema

    def __repr__(self) -> str:
        """
        Return a short debug string with the row count and the cursor id.

        A cursor id of -1 is shown when the result has no cursor.
        """
        cid = self.cursor.cid if self.cursor else -1
        return (
            f"<{self.__class__.__name__} at 0x{id(self):x} "
            f"Rows={len(self.rows)}, Cursor={cid}>"
        )
|