图片解析应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

377 lines
12 KiB

  1. from typing import List, Optional, Union
  2. class Query:
  3. """
  4. Query is used to build complex queries that have more parameters than just
  5. the query string. The query string is set in the constructor, and other
  6. options have setter functions.
  7. The setter functions return the query object, so they can be chained,
  8. i.e. `Query("foo").verbatim().filter(...)` etc.
  9. """
  10. def __init__(self, query_string: str) -> None:
  11. """
  12. Create a new query object.
  13. The query string is set in the constructor, and other options have
  14. setter functions.
  15. """
  16. self._query_string: str = query_string
  17. self._offset: int = 0
  18. self._num: int = 10
  19. self._no_content: bool = False
  20. self._no_stopwords: bool = False
  21. self._fields: Optional[List[str]] = None
  22. self._verbatim: bool = False
  23. self._with_payloads: bool = False
  24. self._with_scores: bool = False
  25. self._scorer: Optional[str] = None
  26. self._filters: List = list()
  27. self._ids: Optional[List[str]] = None
  28. self._slop: int = -1
  29. self._timeout: Optional[float] = None
  30. self._in_order: bool = False
  31. self._sortby: Optional[SortbyField] = None
  32. self._return_fields: List = []
  33. self._return_fields_decode_as: dict = {}
  34. self._summarize_fields: List = []
  35. self._highlight_fields: List = []
  36. self._language: Optional[str] = None
  37. self._expander: Optional[str] = None
  38. self._dialect: Optional[int] = None
  39. def query_string(self) -> str:
  40. """Return the query string of this query only."""
  41. return self._query_string
  42. def limit_ids(self, *ids) -> "Query":
  43. """Limit the results to a specific set of pre-known document
  44. ids of any length."""
  45. self._ids = ids
  46. return self
  47. def return_fields(self, *fields) -> "Query":
  48. """Add fields to return fields."""
  49. for field in fields:
  50. self.return_field(field)
  51. return self
  52. def return_field(
  53. self,
  54. field: str,
  55. as_field: Optional[str] = None,
  56. decode_field: Optional[bool] = True,
  57. encoding: Optional[str] = "utf8",
  58. ) -> "Query":
  59. """
  60. Add a field to the list of fields to return.
  61. - **field**: The field to include in query results
  62. - **as_field**: The alias for the field
  63. - **decode_field**: Whether to decode the field from bytes to string
  64. - **encoding**: The encoding to use when decoding the field
  65. """
  66. self._return_fields.append(field)
  67. self._return_fields_decode_as[field] = encoding if decode_field else None
  68. if as_field is not None:
  69. self._return_fields += ("AS", as_field)
  70. return self
  71. def _mk_field_list(self, fields: List[str]) -> List:
  72. if not fields:
  73. return []
  74. return [fields] if isinstance(fields, str) else list(fields)
  75. def summarize(
  76. self,
  77. fields: Optional[List] = None,
  78. context_len: Optional[int] = None,
  79. num_frags: Optional[int] = None,
  80. sep: Optional[str] = None,
  81. ) -> "Query":
  82. """
  83. Return an abridged format of the field, containing only the segments of
  84. the field which contain the matching term(s).
  85. If `fields` is specified, then only the mentioned fields are
  86. summarized; otherwise all results are summarized.
  87. Server side defaults are used for each option (except `fields`)
  88. if not specified
  89. - **fields** List of fields to summarize. All fields are summarized
  90. if not specified
  91. - **context_len** Amount of context to include with each fragment
  92. - **num_frags** Number of fragments per document
  93. - **sep** Separator string to separate fragments
  94. """
  95. args = ["SUMMARIZE"]
  96. fields = self._mk_field_list(fields)
  97. if fields:
  98. args += ["FIELDS", str(len(fields))] + fields
  99. if context_len is not None:
  100. args += ["LEN", str(context_len)]
  101. if num_frags is not None:
  102. args += ["FRAGS", str(num_frags)]
  103. if sep is not None:
  104. args += ["SEPARATOR", sep]
  105. self._summarize_fields = args
  106. return self
  107. def highlight(
  108. self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None
  109. ) -> None:
  110. """
  111. Apply specified markup to matched term(s) within the returned field(s).
  112. - **fields** If specified then only those mentioned fields are
  113. highlighted, otherwise all fields are highlighted
  114. - **tags** A list of two strings to surround the match.
  115. """
  116. args = ["HIGHLIGHT"]
  117. fields = self._mk_field_list(fields)
  118. if fields:
  119. args += ["FIELDS", str(len(fields))] + fields
  120. if tags:
  121. args += ["TAGS"] + list(tags)
  122. self._highlight_fields = args
  123. return self
  124. def language(self, language: str) -> "Query":
  125. """
  126. Analyze the query as being in the specified language.
  127. :param language: The language (e.g. `chinese` or `english`)
  128. """
  129. self._language = language
  130. return self
  131. def slop(self, slop: int) -> "Query":
  132. """Allow a maximum of N intervening non matched terms between
  133. phrase terms (0 means exact phrase).
  134. """
  135. self._slop = slop
  136. return self
  137. def timeout(self, timeout: float) -> "Query":
  138. """overrides the timeout parameter of the module"""
  139. self._timeout = timeout
  140. return self
  141. def in_order(self) -> "Query":
  142. """
  143. Match only documents where the query terms appear in
  144. the same order in the document.
  145. i.e. for the query "hello world", we do not match "world hello"
  146. """
  147. self._in_order = True
  148. return self
  149. def scorer(self, scorer: str) -> "Query":
  150. """
  151. Use a different scoring function to evaluate document relevance.
  152. Default is `TFIDF`.
  153. :param scorer: The scoring function to use
  154. (e.g. `TFIDF.DOCNORM` or `BM25`)
  155. """
  156. self._scorer = scorer
  157. return self
  158. def get_args(self) -> List[str]:
  159. """Format the redis arguments for this query and return them."""
  160. args = [self._query_string]
  161. args += self._get_args_tags()
  162. args += self._summarize_fields + self._highlight_fields
  163. args += ["LIMIT", self._offset, self._num]
  164. return args
  165. def _get_args_tags(self) -> List[str]:
  166. args = []
  167. if self._no_content:
  168. args.append("NOCONTENT")
  169. if self._fields:
  170. args.append("INFIELDS")
  171. args.append(len(self._fields))
  172. args += self._fields
  173. if self._verbatim:
  174. args.append("VERBATIM")
  175. if self._no_stopwords:
  176. args.append("NOSTOPWORDS")
  177. if self._filters:
  178. for flt in self._filters:
  179. if not isinstance(flt, Filter):
  180. raise AttributeError("Did not receive a Filter object.")
  181. args += flt.args
  182. if self._with_payloads:
  183. args.append("WITHPAYLOADS")
  184. if self._scorer:
  185. args += ["SCORER", self._scorer]
  186. if self._with_scores:
  187. args.append("WITHSCORES")
  188. if self._ids:
  189. args.append("INKEYS")
  190. args.append(len(self._ids))
  191. args += self._ids
  192. if self._slop >= 0:
  193. args += ["SLOP", self._slop]
  194. if self._timeout is not None:
  195. args += ["TIMEOUT", self._timeout]
  196. if self._in_order:
  197. args.append("INORDER")
  198. if self._return_fields:
  199. args.append("RETURN")
  200. args.append(len(self._return_fields))
  201. args += self._return_fields
  202. if self._sortby:
  203. if not isinstance(self._sortby, SortbyField):
  204. raise AttributeError("Did not receive a SortByField.")
  205. args.append("SORTBY")
  206. args += self._sortby.args
  207. if self._language:
  208. args += ["LANGUAGE", self._language]
  209. if self._expander:
  210. args += ["EXPANDER", self._expander]
  211. if self._dialect:
  212. args += ["DIALECT", self._dialect]
  213. return args
  214. def paging(self, offset: int, num: int) -> "Query":
  215. """
  216. Set the paging for the query (defaults to 0..10).
  217. - **offset**: Paging offset for the results. Defaults to 0
  218. - **num**: How many results do we want
  219. """
  220. self._offset = offset
  221. self._num = num
  222. return self
  223. def verbatim(self) -> "Query":
  224. """Set the query to be verbatim, i.e. use no query expansion
  225. or stemming.
  226. """
  227. self._verbatim = True
  228. return self
  229. def no_content(self) -> "Query":
  230. """Set the query to only return ids and not the document content."""
  231. self._no_content = True
  232. return self
  233. def no_stopwords(self) -> "Query":
  234. """
  235. Prevent the query from being filtered for stopwords.
  236. Only useful in very big queries that you are certain contain
  237. no stopwords.
  238. """
  239. self._no_stopwords = True
  240. return self
  241. def with_payloads(self) -> "Query":
  242. """Ask the engine to return document payloads."""
  243. self._with_payloads = True
  244. return self
  245. def with_scores(self) -> "Query":
  246. """Ask the engine to return document search scores."""
  247. self._with_scores = True
  248. return self
  249. def limit_fields(self, *fields: List[str]) -> "Query":
  250. """
  251. Limit the search to specific TEXT fields only.
  252. - **fields**: A list of strings, case sensitive field names
  253. from the defined schema.
  254. """
  255. self._fields = fields
  256. return self
  257. def add_filter(self, flt: "Filter") -> "Query":
  258. """
  259. Add a numeric or geo filter to the query.
  260. **Currently only one of each filter is supported by the engine**
  261. - **flt**: A NumericFilter or GeoFilter object, used on a
  262. corresponding field
  263. """
  264. self._filters.append(flt)
  265. return self
  266. def sort_by(self, field: str, asc: bool = True) -> "Query":
  267. """
  268. Add a sortby field to the query.
  269. - **field** - the name of the field to sort by
  270. - **asc** - when `True`, sorting will be done in asceding order
  271. """
  272. self._sortby = SortbyField(field, asc)
  273. return self
  274. def expander(self, expander: str) -> "Query":
  275. """
  276. Add a expander field to the query.
  277. - **expander** - the name of the expander
  278. """
  279. self._expander = expander
  280. return self
  281. def dialect(self, dialect: int) -> "Query":
  282. """
  283. Add a dialect field to the query.
  284. - **dialect** - dialect version to execute the query under
  285. """
  286. self._dialect = dialect
  287. return self
  288. class Filter:
  289. def __init__(self, keyword: str, field: str, *args: List[str]) -> None:
  290. self.args = [keyword, field] + list(args)
  291. class NumericFilter(Filter):
  292. INF = "+inf"
  293. NEG_INF = "-inf"
  294. def __init__(
  295. self,
  296. field: str,
  297. minval: Union[int, str],
  298. maxval: Union[int, str],
  299. minExclusive: bool = False,
  300. maxExclusive: bool = False,
  301. ) -> None:
  302. args = [
  303. minval if not minExclusive else f"({minval}",
  304. maxval if not maxExclusive else f"({maxval}",
  305. ]
  306. Filter.__init__(self, "FILTER", field, *args)
  307. class GeoFilter(Filter):
  308. METERS = "m"
  309. KILOMETERS = "km"
  310. FEET = "ft"
  311. MILES = "mi"
  312. def __init__(
  313. self, field: str, lon: float, lat: float, radius: float, unit: str = KILOMETERS
  314. ) -> None:
  315. Filter.__init__(self, "GEOFILTER", field, lon, lat, radius, unit)
  316. class SortbyField:
  317. def __init__(self, field: str, asc=True) -> None:
  318. self.args = [field, "ASC" if asc else "DESC"]