图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

383 lines
11 KiB

from typing import List, Union
  2. FIELDNAME = object()
  3. class Limit:
  4. def __init__(self, offset: int = 0, count: int = 0) -> None:
  5. self.offset = offset
  6. self.count = count
  7. def build_args(self):
  8. if self.count:
  9. return ["LIMIT", str(self.offset), str(self.count)]
  10. else:
  11. return []
  12. class Reducer:
  13. """
  14. Base reducer object for all reducers.
  15. See the `redisearch.reducers` module for the actual reducers.
  16. """
  17. NAME = None
  18. def __init__(self, *args: List[str]) -> None:
  19. self._args = args
  20. self._field = None
  21. self._alias = None
  22. def alias(self, alias: str) -> "Reducer":
  23. """
  24. Set the alias for this reducer.
  25. ### Parameters
  26. - **alias**: The value of the alias for this reducer. If this is the
  27. special value `aggregation.FIELDNAME` then this reducer will be
  28. aliased using the same name as the field upon which it operates.
  29. Note that using `FIELDNAME` is only possible on reducers which
  30. operate on a single field value.
  31. This method returns the `Reducer` object making it suitable for
  32. chaining.
  33. """
  34. if alias is FIELDNAME:
  35. if not self._field:
  36. raise ValueError("Cannot use FIELDNAME alias with no field")
  37. # Chop off initial '@'
  38. alias = self._field[1:]
  39. self._alias = alias
  40. return self
  41. @property
  42. def args(self) -> List[str]:
  43. return self._args
  44. class SortDirection:
  45. """
  46. This special class is used to indicate sort direction.
  47. """
  48. DIRSTRING = None
  49. def __init__(self, field: str) -> None:
  50. self.field = field
  51. class Asc(SortDirection):
  52. """
  53. Indicate that the given field should be sorted in ascending order
  54. """
  55. DIRSTRING = "ASC"
  56. class Desc(SortDirection):
  57. """
  58. Indicate that the given field should be sorted in descending order
  59. """
  60. DIRSTRING = "DESC"
  61. class AggregateRequest:
  62. """
  63. Aggregation request which can be passed to `Client.aggregate`.
  64. """
  65. def __init__(self, query: str = "*") -> None:
  66. """
  67. Create an aggregation request. This request may then be passed to
  68. `client.aggregate()`.
  69. In order for the request to be usable, it must contain at least one
  70. group.
  71. - **query** Query string for filtering records.
  72. All member methods (except `build_args()`)
  73. return the object itself, making them useful for chaining.
  74. """
  75. self._query = query
  76. self._aggregateplan = []
  77. self._loadfields = []
  78. self._loadall = False
  79. self._max = 0
  80. self._with_schema = False
  81. self._verbatim = False
  82. self._cursor = []
  83. self._dialect = None
  84. self._add_scores = False
  85. def load(self, *fields: List[str]) -> "AggregateRequest":
  86. """
  87. Indicate the fields to be returned in the response. These fields are
  88. returned in addition to any others implicitly specified.
  89. ### Parameters
  90. - **fields**: If fields not specified, all the fields will be loaded.
  91. Otherwise, fields should be given in the format of `@field`.
  92. """
  93. if fields:
  94. self._loadfields.extend(fields)
  95. else:
  96. self._loadall = True
  97. return self
  98. def group_by(
  99. self, fields: List[str], *reducers: Union[Reducer, List[Reducer]]
  100. ) -> "AggregateRequest":
  101. """
  102. Specify by which fields to group the aggregation.
  103. ### Parameters
  104. - **fields**: Fields to group by. This can either be a single string,
  105. or a list of strings. both cases, the field should be specified as
  106. `@field`.
  107. - **reducers**: One or more reducers. Reducers may be found in the
  108. `aggregation` module.
  109. """
  110. fields = [fields] if isinstance(fields, str) else fields
  111. reducers = [reducers] if isinstance(reducers, Reducer) else reducers
  112. ret = ["GROUPBY", str(len(fields)), *fields]
  113. for reducer in reducers:
  114. ret += ["REDUCE", reducer.NAME, str(len(reducer.args))]
  115. ret.extend(reducer.args)
  116. if reducer._alias is not None:
  117. ret += ["AS", reducer._alias]
  118. self._aggregateplan.extend(ret)
  119. return self
  120. def apply(self, **kwexpr) -> "AggregateRequest":
  121. """
  122. Specify one or more projection expressions to add to each result
  123. ### Parameters
  124. - **kwexpr**: One or more key-value pairs for a projection. The key is
  125. the alias for the projection, and the value is the projection
  126. expression itself, for example `apply(square_root="sqrt(@foo)")`
  127. """
  128. for alias, expr in kwexpr.items():
  129. ret = ["APPLY", expr]
  130. if alias is not None:
  131. ret += ["AS", alias]
  132. self._aggregateplan.extend(ret)
  133. return self
  134. def limit(self, offset: int, num: int) -> "AggregateRequest":
  135. """
  136. Sets the limit for the most recent group or query.
  137. If no group has been defined yet (via `group_by()`) then this sets
  138. the limit for the initial pool of results from the query. Otherwise,
  139. this limits the number of items operated on from the previous group.
  140. Setting a limit on the initial search results may be useful when
  141. attempting to execute an aggregation on a sample of a large data set.
  142. ### Parameters
  143. - **offset**: Result offset from which to begin paging
  144. - **num**: Number of results to return
  145. Example of sorting the initial results:
  146. ```
  147. AggregateRequest("@sale_amount:[10000, inf]")\
  148. .limit(0, 10)\
  149. .group_by("@state", r.count())
  150. ```
  151. Will only group by the states found in the first 10 results of the
  152. query `@sale_amount:[10000, inf]`. On the other hand,
  153. ```
  154. AggregateRequest("@sale_amount:[10000, inf]")\
  155. .limit(0, 1000)\
  156. .group_by("@state", r.count()\
  157. .limit(0, 10)
  158. ```
  159. Will group all the results matching the query, but only return the
  160. first 10 groups.
  161. If you only wish to return a *top-N* style query, consider using
  162. `sort_by()` instead.
  163. """
  164. _limit = Limit(offset, num)
  165. self._aggregateplan.extend(_limit.build_args())
  166. return self
  167. def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest":
  168. """
  169. Indicate how the results should be sorted. This can also be used for
  170. *top-N* style queries
  171. ### Parameters
  172. - **fields**: The fields by which to sort. This can be either a single
  173. field or a list of fields. If you wish to specify order, you can
  174. use the `Asc` or `Desc` wrapper classes.
  175. - **max**: Maximum number of results to return. This can be
  176. used instead of `LIMIT` and is also faster.
  177. Example of sorting by `foo` ascending and `bar` descending:
  178. ```
  179. sort_by(Asc("@foo"), Desc("@bar"))
  180. ```
  181. Return the top 10 customers:
  182. ```
  183. AggregateRequest()\
  184. .group_by("@customer", r.sum("@paid").alias(FIELDNAME))\
  185. .sort_by(Desc("@paid"), max=10)
  186. ```
  187. """
  188. if isinstance(fields, (str, SortDirection)):
  189. fields = [fields]
  190. fields_args = []
  191. for f in fields:
  192. if isinstance(f, SortDirection):
  193. fields_args += [f.field, f.DIRSTRING]
  194. else:
  195. fields_args += [f]
  196. ret = ["SORTBY", str(len(fields_args))]
  197. ret.extend(fields_args)
  198. max = kwargs.get("max", 0)
  199. if max > 0:
  200. ret += ["MAX", str(max)]
  201. self._aggregateplan.extend(ret)
  202. return self
  203. def filter(self, expressions: Union[str, List[str]]) -> "AggregateRequest":
  204. """
  205. Specify filter for post-query results using predicates relating to
  206. values in the result set.
  207. ### Parameters
  208. - **fields**: Fields to group by. This can either be a single string,
  209. or a list of strings.
  210. """
  211. if isinstance(expressions, str):
  212. expressions = [expressions]
  213. for expression in expressions:
  214. self._aggregateplan.extend(["FILTER", expression])
  215. return self
  216. def with_schema(self) -> "AggregateRequest":
  217. """
  218. If set, the `schema` property will contain a list of `[field, type]`
  219. entries in the result object.
  220. """
  221. self._with_schema = True
  222. return self
  223. def add_scores(self) -> "AggregateRequest":
  224. """
  225. If set, includes the score as an ordinary field of the row.
  226. """
  227. self._add_scores = True
  228. return self
  229. def verbatim(self) -> "AggregateRequest":
  230. self._verbatim = True
  231. return self
  232. def cursor(self, count: int = 0, max_idle: float = 0.0) -> "AggregateRequest":
  233. args = ["WITHCURSOR"]
  234. if count:
  235. args += ["COUNT", str(count)]
  236. if max_idle:
  237. args += ["MAXIDLE", str(max_idle * 1000)]
  238. self._cursor = args
  239. return self
  240. def build_args(self) -> List[str]:
  241. # @foo:bar ...
  242. ret = [self._query]
  243. if self._with_schema:
  244. ret.append("WITHSCHEMA")
  245. if self._verbatim:
  246. ret.append("VERBATIM")
  247. if self._add_scores:
  248. ret.append("ADDSCORES")
  249. if self._cursor:
  250. ret += self._cursor
  251. if self._loadall:
  252. ret.append("LOAD")
  253. ret.append("*")
  254. elif self._loadfields:
  255. ret.append("LOAD")
  256. ret.append(str(len(self._loadfields)))
  257. ret.extend(self._loadfields)
  258. if self._dialect:
  259. ret.extend(["DIALECT", self._dialect])
  260. ret.extend(self._aggregateplan)
  261. return ret
  262. def dialect(self, dialect: int) -> "AggregateRequest":
  263. """
  264. Add a dialect field to the aggregate command.
  265. - **dialect** - dialect version to execute the query under
  266. """
  267. self._dialect = dialect
  268. return self
  269. class Cursor:
  270. def __init__(self, cid: int) -> None:
  271. self.cid = cid
  272. self.max_idle = 0
  273. self.count = 0
  274. def build_args(self):
  275. args = [str(self.cid)]
  276. if self.max_idle:
  277. args += ["MAXIDLE", str(self.max_idle)]
  278. if self.count:
  279. args += ["COUNT", str(self.count)]
  280. return args
  281. class AggregateResult:
  282. def __init__(self, rows, cursor: Cursor, schema) -> None:
  283. self.rows = rows
  284. self.cursor = cursor
  285. self.schema = schema
  286. def __repr__(self) -> (str, str):
  287. cid = self.cursor.cid if self.cursor else -1
  288. return (
  289. f"<{self.__class__.__name__} at 0x{id(self):x} "
  290. f"Rows={len(self.rows)}, Cursor={cid}>"
  291. )