Browse Source

新增检索逻辑

master
maojian 1 month ago
parent
commit
564e0fbad6
  1. 55
      src/main/java/com/bw/search/controller/CharacterController.java
  2. 24
      src/main/java/com/bw/search/entity/TermEntity.java
  3. 28
      src/main/java/com/bw/search/service/CharacterService.java
  4. 322
      src/main/java/com/bw/search/service/impl/CharacterServiceImpl.java
  5. 163
      src/main/java/com/bw/search/service/impl/RagSearchServiceImpl.java

55
src/main/java/com/bw/search/controller/CharacterController.java

@ -0,0 +1,55 @@
package com.bw.search.controller;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import com.bw.search.common.Res;
import com.bw.search.entity.TermEntity;
import com.bw.search.service.CharacterService;
import lombok.extern.slf4j.Slf4j;
/**
* 人物操作控制层
* @author jian.mao
* @date 2026年1月6日
* @description
*/
@RestController
@CrossOrigin
@RequestMapping("/api")
@Slf4j
public class CharacterController {
@Autowired
private CharacterService characterService;
/**
* 查询专家数据
*/
@GetMapping("/characters")
public Res<?> getCharacter(
@RequestParam(value = "page", defaultValue = "1", required = false) Integer page,
@RequestParam(value = "size", defaultValue = "10", required = false) Integer size) {
return characterService.getCharacter(page,size);
}
/**
* 根据条件查询人物数据
* @param termEntity
* @return
*/
@PostMapping("/charactersbyterm")
public Res<?> getCharacterByTerm(@RequestBody TermEntity termEntity){
return characterService.getCharacterByTerm(termEntity);
}
}

24
src/main/java/com/bw/search/entity/TermEntity.java

@ -0,0 +1,24 @@
package com.bw.search.entity;
import java.util.List;
import lombok.Data;
/**
* 查询条件
* @author jian.mao
* @date 2026年1月14日
* @description
*/
@Data
public class TermEntity {
private List<String> fullNames;
private List<String> affiliations;
private List<String> countrys;
private List<String> positions;
private List<String> researchFocus;
private Integer size;
private Integer page;
}

28
src/main/java/com/bw/search/service/CharacterService.java

@ -0,0 +1,28 @@
package com.bw.search.service;
import com.bw.search.common.Res;
import com.bw.search.entity.TermEntity;
/**
* 人物操作业务层
* @author jian.mao
* @date 2026年1月6日
* @description
*/
public interface CharacterService {
/**
* 获取人物业务层接口
* @param page
* @param size
*/
public Res<?> getCharacter(Integer page, Integer size);
/**
* 获取人物数据 by 条件
* @param termEntity
* @return
*/
public Res<?> getCharacterByTerm(TermEntity termEntity);
}

322
src/main/java/com/bw/search/service/impl/CharacterServiceImpl.java

@ -0,0 +1,322 @@
package com.bw.search.service.impl;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Resource;
import org.apache.http.HttpResponse;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.springframework.stereotype.Service;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.bw.search.common.Res;
import com.bw.search.config.EsConfig;
import com.bw.search.entity.TermEntity;
import com.bw.search.service.CharacterService;
import lombok.extern.slf4j.Slf4j;
/**
* 人物操作业务接口实现类
* @author jian.mao
* @date 2026年1月6日
* @description
*/
@Service
@Slf4j
public class CharacterServiceImpl implements CharacterService {
@Resource
private EsConfig esConfig;
@Override
public Res<?> getCharacter(Integer page, Integer size) {
// TODO Auto-generated method stub
CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
credentialsProvider.setCredentials(
AuthScope.ANY,
new UsernamePasswordCredentials(esConfig.getUsername(),esConfig.getPassword())
);
CloseableHttpClient httpClient = null;
try {
if (esConfig.getUsername() != null && !esConfig.getUsername().trim().equals("")) {
httpClient = HttpClients.custom()
.setDefaultCredentialsProvider(credentialsProvider)
.build();
} else {
httpClient = HttpClients.custom().build();
}
// ================== 构建查询 DSL ==================
int from = (page - 1) * size;
Map<String, Object> query = new HashMap<String, Object>();
query.put("from", from);
query.put("size", size);
// must 条件
List<Map<String, Object>> mustList = new ArrayList<Map<String, Object>>();
Map<String, Object> bool = new HashMap<String, Object>();
bool.put("must", mustList);
Map<String, Object> queryBody = new HashMap<String, Object>();
queryBody.put("bool", bool);
query.put("query", queryBody);
// sort
List<Map<String, Object>> sortList = new ArrayList<Map<String, Object>>();
Map<String, Object> order = new HashMap<String, Object>();
order.put("order", "desc");
Map<String, Object> sortField = new HashMap<String, Object>();
sortField.put("collectionTime", order);
sortList.add(sortField);
query.put("sort", sortList);
log.info("查询条件:{}",JSONObject.toJSONString(query));
// ================== 发起 HTTP 请求 ==================
StringBuffer host = new StringBuffer();
host.append(esConfig.getHost())
.append("/")
.append(esConfig.getIndex())
.append("/_search");
HttpPost httpPost = new HttpPost(host.toString());
httpPost.setHeader("Content-Type", "application/json");
StringEntity entity = new StringEntity(
JSONObject.toJSONString(query),
ContentType.APPLICATION_JSON
);
httpPost.setEntity(entity);
HttpResponse response = httpClient.execute(httpPost);
int statusCode = response.getStatusLine().getStatusCode();
String responseBody = EntityUtils.toString(response.getEntity(), "UTF-8");
if (statusCode != 200) {
log.error("ES 查询失败 status={}, body={}", statusCode, responseBody);
return Res.fail("ES 查询失败");
}
// ================== 解析返回 ==================
JSONObject json = JSONObject.parseObject(responseBody);
JSONObject hits = json.getJSONObject("hits");
Long total = hits.getJSONObject("total").getLong("value");
List<Map<String, Object>> list = new ArrayList<Map<String, Object>>();
JSONArray hitList = hits.getJSONArray("hits");
for (int i = 0; i < hitList.size(); i++) {
JSONObject source = hitList.getJSONObject(i).getJSONObject("_source");
list.add(source);
}
Map<String, Object> result = new HashMap<String, Object>();
result.put("page", page);
result.put("size", size);
result.put("total", total);
result.put("list", list);
return Res.ok(result);
} catch (Exception e) {
return Res.fail("查询任务失败");
} finally {
if (httpClient != null) {
try {
httpClient.close();
} catch (Exception ignored) {}
}
}
}
@Override
public Res<?> getCharacterByTerm(TermEntity term) {
CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
credentialsProvider.setCredentials(
AuthScope.ANY,
new UsernamePasswordCredentials(esConfig.getUsername(), esConfig.getPassword())
);
CloseableHttpClient httpClient = null;
try {
if (esConfig.getUsername() != null && !"".equals(esConfig.getUsername().trim())) {
httpClient = HttpClients.custom()
.setDefaultCredentialsProvider(credentialsProvider)
.build();
} else {
httpClient = HttpClients.custom().build();
}
// ============ 分页 ============
int page = term.getPage() == null || term.getPage() < 1 ? 1 : term.getPage();
int size = term.getSize() == null || term.getSize() < 1 ? 10 : term.getSize();
int from = (page - 1) * size;
Map<String, Object> query = new HashMap<>();
query.put("from", from);
query.put("size", size);
// ============ bool ============
List<Map<String, Object>> mustList = new ArrayList<>();
List<Map<String, Object>> filterList = new ArrayList<>();
Map<String, Object> bool = new HashMap<>();
bool.put("must", mustList);
bool.put("filter", filterList);
// ================= fullName 模糊包含 =================
if (term.getFullNames() != null && !term.getFullNames().isEmpty()) {
List<Map<String, Object>> shouldList = new ArrayList<>();
for (String name : term.getFullNames()) {
if (name != null && !"".equals(name.trim())) {
Map<String, Object> match = new HashMap<>();
Map<String, Object> field = new HashMap<>();
field.put("fullName", name);
match.put("match", field);
shouldList.add(match);
}
}
if (!shouldList.isEmpty()) {
Map<String, Object> shouldBool = new HashMap<>();
shouldBool.put("should", shouldList);
shouldBool.put("minimum_should_match", 1);
Map<String, Object> shouldQuery = new HashMap<>();
shouldQuery.put("bool", shouldBool);
mustList.add(shouldQuery);
}
}
// ================= affiliation 精确 =================
if (needFilter(term.getAffiliations())) {
Map<String, Object> terms = new HashMap<>();
terms.put("affiliation", term.getAffiliations());
Map<String, Object> termsQuery = new HashMap<>();
termsQuery.put("terms", terms);
filterList.add(termsQuery);
}
// ================= country -> geographicInfo =================
if (needFilter(term.getCountrys())) {
Map<String, Object> terms = new HashMap<>();
terms.put("geographicInfo", term.getCountrys());
Map<String, Object> termsQuery = new HashMap<>();
termsQuery.put("terms", terms);
filterList.add(termsQuery);
}
// ================= position =================
if (needFilter(term.getPositions())) {
Map<String, Object> terms = new HashMap<>();
terms.put("position", term.getPositions());
Map<String, Object> termsQuery = new HashMap<>();
termsQuery.put("terms", terms);
filterList.add(termsQuery);
}
// ================= researchFocus数组精确匹配 =================
if (needFilter(term.getResearchFocus())) {
Map<String, Object> terms = new HashMap<>();
terms.put("researchFocus", term.getResearchFocus());
Map<String, Object> termsQuery = new HashMap<>();
termsQuery.put("terms", terms);
filterList.add(termsQuery);
}
Map<String, Object> queryBody = new HashMap<>();
queryBody.put("bool", bool);
query.put("query", queryBody);
// ================= 排序 =================
List<Map<String, Object>> sortList = new ArrayList<>();
Map<String, Object> order = new HashMap<>();
order.put("order", "desc");
Map<String, Object> sortField = new HashMap<>();
sortField.put("collectionTime", order);
sortList.add(sortField);
query.put("sort", sortList);
log.info("查询条件:{}",JSONObject.toJSONString(query));
// ================= 请求 ES =================
String url = esConfig.getHost() + "/" + esConfig.getIndex() + "/_search";
HttpPost httpPost = new HttpPost(url);
httpPost.setHeader("Content-Type", "application/json");
httpPost.setEntity(new StringEntity(JSONObject.toJSONString(query), ContentType.APPLICATION_JSON));
HttpResponse response = httpClient.execute(httpPost);
String responseBody = EntityUtils.toString(response.getEntity(), "UTF-8");
if (response.getStatusLine().getStatusCode() != 200) {
log.error("ES查询失败 body={}", responseBody);
return Res.fail("ES查询失败");
}
// ================= 解析返回 =================
JSONObject json = JSONObject.parseObject(responseBody);
JSONObject hits = json.getJSONObject("hits");
Long total = hits.getJSONObject("total").getLong("value");
List<Map<String, Object>> list = new ArrayList<>();
JSONArray hitList = hits.getJSONArray("hits");
for (int i = 0; i < hitList.size(); i++) {
list.add(hitList.getJSONObject(i).getJSONObject("_source"));
}
Map<String, Object> result = new HashMap<>();
result.put("page", page);
result.put("size", size);
result.put("total", total);
result.put("list", list);
return Res.ok(result);
} catch (Exception e) {
log.error("ES查询异常", e);
return Res.fail("查询失败");
} finally {
try {
if (httpClient != null) httpClient.close();
} catch (Exception ignored) {}
}
}
private boolean needFilter(List<String> list) {
if (list == null || list.isEmpty()) return false;
if (list.size() == 1 && "all".equalsIgnoreCase(list.get(0))) {
return false;
}
return true;
}
}

163
src/main/java/com/bw/search/service/impl/RagSearchServiceImpl.java

@ -1,14 +1,31 @@
package com.bw.search.service.impl; package com.bw.search.service.impl;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import javax.annotation.Resource;
import org.apache.http.HttpResponse;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.bw.search.cache.ConfigCache; import com.bw.search.cache.ConfigCache;
import com.bw.search.common.Res; import com.bw.search.common.Res;
import com.bw.search.config.EsConfig;
import com.bw.search.entity.Constants; import com.bw.search.entity.Constants;
import com.bw.search.entity.SearchResponse;
import com.bw.search.service.RagSearchService; import com.bw.search.service.RagSearchService;
import com.bw.search.utils.DateUtil; import com.bw.search.utils.DateUtil;
@ -18,6 +35,26 @@ import lombok.extern.slf4j.Slf4j;
@Slf4j @Slf4j
public class RagSearchServiceImpl implements RagSearchService { public class RagSearchServiceImpl implements RagSearchService {
@Resource
private EsConfig esConfig;
// @Override
// public Res<?> search(String dataJson) {
// log.info("向量检索参数:{}",dataJson);
// //转换对象
// JSONObject parseObject = JSONObject.parseObject(dataJson);
// String id = parseObject.getString(Constants.ID);
// Map<String, Object> knowResult = getKnowledge(id);
// if(knowResult == null) {
// log.error("向量检索失败!");
// return Res.fail("知识库获取失败!");
// }
// log.info("知识库结果已获取:{}",JSONObject.toJSONString(knowResult));
// //响应体数据
// SearchResponse searchResponse = new SearchResponse();
// searchResponse.setIds(JSONObject.parseArray((String)knowResult.get(Constants.IDS), String.class));
// return Res.ok(searchResponse);
// }
@Override @Override
public Res<?> search(String dataJson) { public Res<?> search(String dataJson) {
log.info("向量检索参数:{}",dataJson); log.info("向量检索参数:{}",dataJson);
@ -30,12 +67,128 @@ public class RagSearchServiceImpl implements RagSearchService {
return Res.fail("知识库获取失败!"); return Res.fail("知识库获取失败!");
} }
log.info("知识库结果已获取:{}",JSONObject.toJSONString(knowResult)); log.info("知识库结果已获取:{}",JSONObject.toJSONString(knowResult));
//响应体数据
SearchResponse searchResponse = new SearchResponse();
searchResponse.setIds(JSONObject.parseArray((String)knowResult.get(Constants.IDS), String.class));
return Res.ok(searchResponse);
//获取完整数据
List<String> ids = JSONObject.parseArray((String)knowResult.get(Constants.IDS), String.class);
return getCharacterByIds(ids);
}
public Res<?> getCharacterByIds(List<String> ids) {
if (ids == null || ids.isEmpty()) {
return Res.fail("ids 不能为空");
}
CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
credentialsProvider.setCredentials(
AuthScope.ANY,
new UsernamePasswordCredentials(
esConfig.getUsername(),
esConfig.getPassword()
)
);
CloseableHttpClient httpClient = null;
try {
if (esConfig.getUsername() != null && !esConfig.getUsername().trim().isEmpty()) {
httpClient = HttpClients.custom()
.setDefaultCredentialsProvider(credentialsProvider)
.build();
} else {
httpClient = HttpClients.custom().build();
}
// ================== 构建查询 DSL ==================
Map<String, Object> query = new HashMap<>();
// must 条件
List<Map<String, Object>> mustList = new ArrayList<>();
// terms 查询 knowId
Map<String, Object> terms = new HashMap<>();
Map<String, Object> termsField = new HashMap<>();
// 如果 knowId text请改成 knowId.keyword
termsField.put("knowId", ids);
terms.put("terms", termsField);
mustList.add(terms);
Map<String, Object> bool = new HashMap<>();
bool.put("must", mustList);
Map<String, Object> queryBody = new HashMap<>();
queryBody.put("bool", bool);
query.put("query", queryBody);
// sort
List<Map<String, Object>> sortList = new ArrayList<>();
Map<String, Object> order = new HashMap<>();
order.put("order", "desc");
Map<String, Object> sortField = new HashMap<>();
sortField.put("collectionTime", order);
sortList.add(sortField);
query.put("sort", sortList);
// ================== 发起 HTTP 请求 ==================
String url = esConfig.getHost()
+ "/"
+ esConfig.getIndex()
+ "/_search";
HttpPost httpPost = new HttpPost(url);
httpPost.setHeader("Content-Type", "application/json");
StringEntity entity = new StringEntity(
JSONObject.toJSONString(query),
ContentType.APPLICATION_JSON
);
httpPost.setEntity(entity);
HttpResponse response = httpClient.execute(httpPost);
int statusCode = response.getStatusLine().getStatusCode();
String responseBody = EntityUtils.toString(response.getEntity(), "UTF-8");
if (statusCode != 200) {
log.error("ES 查询失败 status={}, body={}", statusCode, responseBody);
return Res.fail("ES 查询失败");
} }
// ================== 解析返回 ==================
JSONObject json = JSONObject.parseObject(responseBody);
JSONObject hits = json.getJSONObject("hits");
Long total = hits.getJSONObject("total").getLong("value");
List<Map<String, Object>> list = new ArrayList<>();
JSONArray hitList = hits.getJSONArray("hits");
for (int i = 0; i < hitList.size(); i++) {
JSONObject source = hitList
.getJSONObject(i)
.getJSONObject("_source");
list.add(source);
}
return Res.ok(list);
} catch (Exception e) {
log.error("ES 查询异常", e);
return Res.fail("查询失败");
} finally {
if (httpClient != null) {
try {
httpClient.close();
} catch (Exception ignored) {}
}
}
}
/** /**

Loading…
Cancel
Save