Browse Source

新增deepseek本地模型接入

master
maojian 2 months ago
parent
commit
f5ca987ed9
  1. 3
      src/main/java/com/bw/qanda/cache/ConfigCache.java
  2. 11
      src/main/java/com/bw/qanda/controller/QandaController.java
  3. 38
      src/main/java/com/bw/qanda/handler/MainHandler.java
  4. 5
      src/main/java/com/bw/qanda/service/QandATaskService.java
  5. 6
      src/main/java/com/bw/qanda/service/QandaService.java
  6. 139
      src/main/java/com/bw/qanda/service/impl/QandATaskServiceImpl.java
  7. 33
      src/main/java/com/bw/qanda/service/impl/QandaServiceImpl.java
  8. 2
      src/main/java/com/bw/qanda/utils/DownLoadUtil.java
  9. 10
      src/main/java/com/bw/qanda/utils/GPTResultParseUtil.java
  10. 4
      src/main/resources/application.yml

3
src/main/java/com/bw/qanda/cache/ConfigCache.java

@ -17,8 +17,7 @@ public class ConfigCache {
public static boolean isStart = true;
/*****任务队列*****/
public static LinkedBlockingDeque<Map<String, Object>> taskQueue = new LinkedBlockingDeque<Map<String,Object>>();
public static LinkedBlockingDeque<Map<String, Object>> videoCommitTaskQueue = new LinkedBlockingDeque<Map<String,Object>>();
public static LinkedBlockingDeque<Map<String, Object>> videoResultTaskQueue = new LinkedBlockingDeque<Map<String,Object>>();
public static LinkedBlockingDeque<Map<String, Object>> localTaskQueue = new LinkedBlockingDeque<Map<String,Object>>();
/**

11
src/main/java/com/bw/qanda/controller/QandaController.java

@ -33,6 +33,17 @@ public class QandaController {
return response;
}
/**
* 本地模型调用
* @param dataJson
* @return
*/
@PostMapping("/putLocalQuestion")
@ResponseBody
public String putLocalQuestion(@RequestBody String dataJson){
String response = qandaService.putLocalQuestion(dataJson);
return response;
}
@RequestMapping(value = "/hello", method = RequestMethod.GET)
@ResponseBody
public String hello(String param, String token) {

38
src/main/java/com/bw/qanda/handler/MainHandler.java

@ -43,6 +43,8 @@ public class MainHandler implements ApplicationRunner {
@Value("${task.task-queue-path}")
private String taskPath;
@Value("${task.local-task-queue-path}")
private String localTaskPath;
@Resource
private QandATaskService qandATaskService;
@Resource
@ -95,8 +97,28 @@ public class MainHandler implements ApplicationRunner {
});
textConsumerThread.start();
log.info("问答模型任务消费线程启动-----");
//消费文本翻译任务队列数据
Thread localTextConsumerThread = new Thread(() -> {
while (true) {
try {
log.info("本地模型任务队列长度:{}",ConfigCache.localTaskQueue.size());
// 从队列中获取任务
Map<String, Object> task = ConfigCache.localTaskQueue.take();
// 提交给线程池执行
executor.execute(() -> localQandAExec(task));
} catch (InterruptedException e) {
// 恢复中断状态
Thread.currentThread().interrupt();
log.error("创建任务消费线程被中断");
break;
}
}
});
localTextConsumerThread.start();
log.info("本地问答模型任务消费线程启动-----");
//加载任务
readTask(taskPath, ConfigCache.taskQueue);
readTask(localTaskPath, ConfigCache.localTaskQueue);
//钩子拉起
waitDown();
}
@ -104,6 +126,9 @@ public class MainHandler implements ApplicationRunner {
public void qandAExec(Map<String, Object> task) {
qandATaskService.qandA(task);
}
public void localQandAExec(Map<String, Object> task) {
qandATaskService.localQandA(task);
}
@SuppressWarnings("unchecked")
public static void readTask(String path, LinkedBlockingDeque<Map<String, Object>> queue) {
@ -160,5 +185,18 @@ public class MainHandler implements ApplicationRunner {
break;
}
}
while (true) {
if (ConfigCache.localTaskQueue.size() > 0) {
try {
Map<String, Object> task = ConfigCache.localTaskQueue.take();
FileUtil.writeFile(localTaskPath, JSONObject.toJSONString(task));
} catch (InterruptedException e) {
e.printStackTrace();
}
} else {
log.info("taskQueue write is file end");
break;
}
}
}
}

5
src/main/java/com/bw/qanda/service/QandATaskService.java

@ -15,4 +15,9 @@ public interface QandATaskService {
* @param task
*/
public void qandA(Map<String, Object> task);
/**
* 本地问答执行方法
* @param task
*/
public void localQandA(Map<String, Object> task);
}

6
src/main/java/com/bw/qanda/service/QandaService.java

@ -14,5 +14,11 @@ public interface QandaService {
* @return
*/
public String putQuestion(String dataJson);
/**
* 本地问答
* @param dataJson
* @return
*/
public String putLocalQuestion(String dataJson);
}

139
src/main/java/com/bw/qanda/service/impl/QandATaskServiceImpl.java

@ -35,6 +35,10 @@ public class QandATaskServiceImpl implements QandATaskService {
private SpringBootKafka springBootKafka;
@Value("${customize-kafka.producer.topic}")
private String topic;
@Value("${localModel.apiKey}")
private String localQaKey;
@Value("${localModel.url}")
private String localModelUrl;
@Override
public void qandA(Map<String, Object> task) {
try {
@ -144,6 +148,110 @@ public class QandATaskServiceImpl implements QandATaskService {
}
}
@Override
public void localQandA(Map<String, Object> task) {
try {
log.info("本地任务:{}",JSONObject.toJSONString(task));
//结果接收
Map<String, Object> result = new HashMap<String, Object>(16);
//结果内容
StringBuffer chatContent = new StringBuffer();
//输入配置
Map<String, Object> input = (Map<String, Object>) task.get(Constants.INPUT);
//输出配置
Map<String, Object> output = (Map<String, Object>) task.get(Constants.OUTPUT);
//数据源
Map<String, Object> data = (Map<String, Object>) task.get(Constants.DATA);
int scenesId = (int) task.get(Constants.SCENES_ID);
int version = (int) task.get(Constants.VERSION);
String pauseKey = scenesId + "_" + version;
if (!PauseTool.CACHE.containsKey(pauseKey)) {
log.info("流程:{}的版本:{}已失效,任务跳过", scenesId, version);
return;
}
//fieldType自定义输出字段 0 关闭1-开启如果开启则拼接form到output里如果关闭则取默认的output拼接
int fieldType = (int) input.get(Constants.FIELDTYPE);
Float temperature = Float.valueOf(input.get(Constants.TEMPERATURE).toString());
Float topP = Float.valueOf(input.get(Constants.TOP_P).toString());
List<Map<String, Object>> prompt = (List<Map<String, Object>>) input.get(Constants.PROMPT);
String answerStr = localQandaRequest(temperature, topP, prompt, data);
log.info("answerStr:" + answerStr);
Map<String, Object> answer = JSONObject.parseObject(answerStr);
try {
//请求成功正常解析
Map<String, Object> message = (Map<String, Object>) answer.get(Constants.MESSAGE);
chatContent.append(message.get(Constants.CONTENT));
} catch (Exception e) {
log.error("问答接口响应体异常:{}", answerStr, e);
// TODO: handle exception
//结果集
Map<String, Object> results = new HashMap<String, Object>(16);
//遍历入库返回结果拼接响应内容
results.put("isLast", 1);
results.put("content", answerStr);
result.put(Constants.RESULTS, JSONObject.toJSONString(results));
result.put(Constants.MESSAGE, "问答异常");
result.put(Constants.STATUS, 2);
task.put(Constants.RESULT, result);
//发送kafka
springBootKafka.send(topic, JSONObject.toJSONString(task));
log.info("数据流转至下游-------");
return;
}
Map<String, Object> results = new HashMap<String, Object>(16);
results.put(Constants.ID, UUID.randomUUID().toString());
results.put(Constants.CONTENT, chatContent.toString());
if (fieldType != 0) {
results.remove(Constants.CONTENT);
try {
//请求成功正常解析
Map<String, Object> message = (Map<String, Object>) answer.get(Constants.MESSAGE);
String reponseContent = (String) message.get(Constants.CONTENT);
Map<String, Object> stringObjectMap = GPTResultParseUtil.parseGPTResult(output, reponseContent);
results.putAll(stringObjectMap);
} catch (Exception e) {
log.error("问答接口响应体异常:{}", answerStr, e);
// TODO: handle exception
//结果集
//遍历入库返回结果拼接响应内容
results.put("isLast", 1);
results.put("content", answerStr);
result.put(Constants.RESULTS, JSONObject.toJSONString(results));
result.put(Constants.MESSAGE, "问答异常");
result.put(Constants.STATUS, 2);
task.put(Constants.RESULT, result);
//发送kafka
springBootKafka.send(topic, JSONObject.toJSONString(task));
log.info("数据流转至下游-------");
return;
}
}
results.put("isLast", 1);
result.put(Constants.RESULTS, JSONObject.toJSONString(results));
result.put(Constants.MESSAGE, "成功");
result.put(Constants.STATUS, 1);
task.put(Constants.RESULT, result);
//发送kafka
springBootKafka.send(topic, JSONObject.toJSONString(task));
log.info("数据流转至下游-------");
} catch (Throwable e) {
log.error("问答处理异常,", e);
//结果集
Map<String, Object> results = new HashMap<String, Object>(16);
Map<String, Object> result = new HashMap<String, Object>(16);
//遍历入库返回结果拼接响应内容
results.put("isLast", 1);
results.put("content", e.getMessage());
result.put(Constants.RESULTS, JSONObject.toJSONString(results));
result.put(Constants.MESSAGE, "异常");
result.put(Constants.STATUS, 2);
task.put(Constants.RESULT, result);
//发送kafka
springBootKafka.send(topic, JSONObject.toJSONString(task));
log.info("数据流转至下游-------");
}
}
/**
* 问题请求
*
@ -178,6 +286,37 @@ public class QandATaskServiceImpl implements QandATaskService {
String html = DownLoadUtil.doPost(Constants.DEEPSEEK_CHAT_URL, JSONObject.toJSONString(params), headers);
return html;
}
/**
* 本地模型请求
* @param temperature
* @param topP
* @param prompt
* @param data
* @return
*/
private String localQandaRequest(Float temperature, Float topP, List<Map<String, Object>> prompt, Map<String, Object> data) {
//新建聊天话术
StringBuffer chatContent = new StringBuffer();
for (Map<String, Object> map : prompt) {
if (Integer.valueOf(map.get(Constants.TYPE).toString()) == Constants.CHAT_TYPE_ONE) {
chatContent.append(map.get(Constants.VALUE));
} else {
String jsonPath = (String) map.get(Constants.VALUE);
chatContent.append(DataUtil.getValue(jsonPath, data).toString());
}
}
Map<String, Object> headers = new HashMap<String, Object>(16);
Map<String, Object> params = new HashMap<String, Object>(16);
List<Map<String, Object>> prompts = buildParam(chatContent.toString(), data);
params.put(Constants.MESSAGES, prompts);
params.put(Constants.TEMPERATURE, temperature);
params.put(Constants.TOP_P, topP);
params.put(Constants.MAX_TOKENS, 512);
headers.put("Content-Type", "application/json");
headers.put("X-API-Key", localQaKey);
String html = DownLoadUtil.doPost(localModelUrl, JSONObject.toJSONString(params), headers);
return html;
}
/**
* 参数构建

33
src/main/java/com/bw/qanda/service/impl/QandaServiceImpl.java

@ -53,5 +53,36 @@ public class QandaServiceImpl implements QandaService {
return JSONObject.toJSONString(response);
}
/**
*本地问答队列搭建
*/
@Override
public String putLocalQuestion(String dataJson) {
Map<String, Object> response = new HashMap<>(16);
int code = 200;
String message = "success";
Map<String, Object> task = null;
try {
task = JSONObject.parseObject(dataJson);
} catch (Exception e) {
log.error("参数结构不合法,", e);
code = 100010;
message = "参数不合法";
}
// 写入队列
try {
if(task.containsKey(Constants.TRACE) && (boolean)task.get(Constants.TRACE)){
ConfigCache.localTaskQueue.putFirst(task);
}else{
ConfigCache.localTaskQueue.put(task);
}
} catch (InterruptedException e) {
log.error("任务写入等待队列异常,", e);
code = 100011;
message = "任务写入等待队列失败";
}
response.put(Constants.CODE, code);
response.put(Constants.MESSAGE, message);
return JSONObject.toJSONString(response);
}
}

2
src/main/java/com/bw/qanda/utils/DownLoadUtil.java

@ -295,7 +295,7 @@ public class DownLoadUtil {
public static String doPost(String url, String params, Map<String, Object>... headers){
String strResult = "";
//设置超时时间
int timeout = 60;
int timeout = 120;
RequestConfig config = RequestConfig.custom().
setConnectTimeout(timeout * 1000).
setConnectionRequestTimeout(timeout * 1000).

10
src/main/java/com/bw/qanda/utils/GPTResultParseUtil.java

@ -1,6 +1,9 @@
package com.bw.qanda.utils;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import com.alibaba.fastjson.JSONException;
import java.util.HashMap;
import java.util.Map;
@ -14,6 +17,7 @@ import java.util.regex.Pattern;
* @description:
* @Date:2024/6/28 10:11
*/
@Slf4j
public class GPTResultParseUtil {
public static Map<String, Object> parseGPTResult(Map<String, Object> output, String gptContent) {
Map<String, Object> jsonResult = new HashMap<>();
@ -34,7 +38,9 @@ public class GPTResultParseUtil {
Pattern pattern = Pattern.compile("\\{.*\\}", Pattern.DOTALL);
Matcher matcher = pattern.matcher(gptContent.replace("\n", ""));
if (matcher.find()) {
JSONObject jsonGPT = JSON.parseObject(matcher.group());
String reslut = matcher.group();
log.info("匹配json的结果:{}",reslut);
JSONObject jsonGPT = JSON.parseObject(reslut);
for (String key : output.keySet()) {
if (jsonGPT.containsKey(key)) {
jsonResult.put(key, jsonGPT.get(key));
@ -45,7 +51,7 @@ public class GPTResultParseUtil {
return null;
}
} catch (Exception ex) {
ex.printStackTrace();
log.error("匹配json结果失败:",ex);
return null;
}
}

4
src/main/resources/application.yml

@ -92,6 +92,10 @@ customize-kafka:
topic: produce_analyze
task:
task-queue-path: ../data/taskQueue.txt
local-task-queue-path: ../data/localTaskQueue.txt
localModel:
apiKey: deepseek-7b-id
url: http://192.168.2.112:8000/v1/chat
threadPool:
corePoolSize: 2
maximumPoolSize: 5

Loading…
Cancel
Save