1 changed files with 144 additions and 0 deletions
@ -0,0 +1,144 @@ |
|||
package com.bw; |
|||
import org.apache.http.HttpEntity; |
|||
import org.apache.http.HttpResponse; |
|||
import org.apache.http.StatusLine; |
|||
import org.apache.http.client.HttpClient; |
|||
import org.apache.http.client.config.RequestConfig; |
|||
import org.apache.http.client.methods.HttpPost; |
|||
import org.apache.http.config.SocketConfig; |
|||
import org.apache.http.entity.StringEntity; |
|||
import org.apache.http.impl.client.HttpClientBuilder; |
|||
import org.apache.http.util.EntityUtils; |
|||
import com.alibaba.fastjson.JSONObject; |
|||
|
|||
import java.util.ArrayList; |
|||
import java.util.HashMap; |
|||
import java.util.List; |
|||
import java.util.Map; |
|||
|
|||
public class EnhancedApiClient { |
|||
|
|||
/** |
|||
* 问题请求 |
|||
* |
|||
* @param authorization |
|||
* @param temperature |
|||
* @param topP |
|||
* @param prompt |
|||
* @return |
|||
*/ |
|||
private static String qandaRequest(String apiUrl,String authorization, String model, Float temperature, Float topP,String question) { |
|||
Map<String, Object> headers = new HashMap<String, Object>(16); |
|||
Map<String, Object> params = new HashMap<String, Object>(16); |
|||
List<Map<String, Object>> prompts = buildParam(question); |
|||
params.put("model", model); |
|||
params.put("messages", prompts); |
|||
params.put("temperature", temperature); |
|||
params.put("top_p", topP); |
|||
params.put("stream", false); |
|||
params.put("max_tokens", 8191); |
|||
headers.put("authorization", authorization); |
|||
headers.put("Content-Type", "application/json"); |
|||
String html = doPost(apiUrl, JSONObject.toJSONString(params), headers); |
|||
return html; |
|||
} |
|||
/** |
|||
* json参数方式POST提交 |
|||
* @param url |
|||
* @param params |
|||
* @return |
|||
*/ |
|||
public static String doPost(String url, String params, Map<String, Object>... headers){ |
|||
String strResult = ""; |
|||
//设置超时时间 |
|||
int timeout = 120; |
|||
RequestConfig config = RequestConfig.custom(). |
|||
setConnectTimeout(timeout * 1000). |
|||
setConnectionRequestTimeout(timeout * 1000). |
|||
setSocketTimeout(timeout * 1000).build(); |
|||
SocketConfig socketConfig = SocketConfig.custom() |
|||
.setSoKeepAlive(false) |
|||
.setSoLinger(1) |
|||
.setSoReuseAddress(true) |
|||
.setSoTimeout(timeout * 1000) |
|||
.setTcpNoDelay(true).build(); |
|||
// AuthCache authCache = new BasicAuthCache(); |
|||
// authCache.put(proxy, new BasicScheme()); |
|||
// HttpClientContext localContext = HttpClientContext.create(); |
|||
// localContext.setAuthCache(authCache); |
|||
// 1. 获取默认的client实例 |
|||
HttpClientBuilder httpBuilder = HttpClientBuilder.create(); |
|||
HttpClient client = httpBuilder.setDefaultSocketConfig(socketConfig).setDefaultRequestConfig(config).build(); |
|||
// HttpClient client = httpBuilder.setDefaultSocketConfig(socketConfig).setDefaultRequestConfig(config).setConnectionManager(cm) |
|||
// .setDefaultCredentialsProvider(credsProvider).build(); |
|||
// 2. 创建httppost实例 |
|||
HttpPost httpPost = new HttpPost(url); |
|||
// httpPost.setConfig(reqConfig); |
|||
if (headers != null && headers.length > 0) { |
|||
Map<String, Object> tempHeaders = headers[0]; |
|||
for (String key : tempHeaders.keySet()) { |
|||
httpPost.setHeader(key, tempHeaders.get(key).toString()); |
|||
} |
|||
} else { |
|||
httpPost.addHeader("Content-Type", "application/json;charset=utf-8"); |
|||
} |
|||
HttpResponse resp = null; |
|||
try { |
|||
httpPost.setEntity(new StringEntity(params,"utf-8")); |
|||
resp = client.execute(httpPost); |
|||
// resp = client.execute(httpPost,localContext); |
|||
StatusLine statusLine = resp.getStatusLine(); |
|||
System.out.println("响应状态为:" + resp.getStatusLine()); |
|||
int notFundCode = 404; |
|||
int successCode = 200; |
|||
if(statusLine.getStatusCode() == successCode){ |
|||
// 7. 获取响应entity |
|||
HttpEntity respEntity = resp.getEntity(); |
|||
strResult = EntityUtils.toString(respEntity, "UTF-8"); |
|||
if(strResult.equals("")){ |
|||
strResult = "Download failed error is:reslut is null"; |
|||
} |
|||
}else{ |
|||
throw new Exception("请求错误,code码为:"+statusLine.getStatusCode()); |
|||
} |
|||
} catch (Exception e) { |
|||
e.printStackTrace(); |
|||
} |
|||
return strResult; |
|||
} |
|||
/** |
|||
* 参数构建 |
|||
* |
|||
* @param prompt |
|||
* @param data |
|||
* @return |
|||
*/ |
|||
private static List<Map<String, Object>> buildParam(String prompt) { |
|||
//聊天体 |
|||
List<Map<String, Object>> prompts = new ArrayList<Map<String, Object>>(); |
|||
Map<String, Object> chat1 = new HashMap<String, Object>(16); |
|||
chat1.put("role", "user"); |
|||
chat1.put("content", prompt); |
|||
prompts.add(chat1); |
|||
return prompts; |
|||
} |
|||
|
|||
public static void main(String[] args) { |
|||
String apiUrl = "https://api.deepseek.com/chat/completions"; |
|||
String apiKey = "Bearer sk-ed18c7b7c68d4d0f9ae315aaba1fb2ad"; |
|||
String model = "deepseek-chat"; |
|||
|
|||
String question = "我获取的数据为\"美利坚合众国马里兰州德特里克堡美国陆军传染病医学研究所病理科。\"我的机构数据库为\"[美国陆军传染病医学研究所,武汉研究所,美国德克萨斯生物医学研究所1\",用我获取的数据在我的库中筛查,匹配到符合的机构,依据相近以及你觉得名字不同但是为同一机构,把我的库中的名字输出给我,只输出机构名字,不要输出额外的内容,如果没有找到,输出无匹配内容。"; |
|||
|
|||
try { |
|||
Float temperature = 0.3f; |
|||
float topP = 0.3f; |
|||
String answer = qandaRequest(apiUrl,apiKey, model, temperature, topP,question); |
|||
System.out.println("\nAPI回答:"); |
|||
System.out.println(answer); |
|||
} catch (Exception e) { |
|||
System.err.println("API调用失败: " + e.getMessage()); |
|||
e.printStackTrace(); |
|||
} |
|||
} |
|||
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue