首页 > 基础资料 博客日记
基于Azure实现Java访问OpenAI
2023-07-24 16:21:43基础资料围观390次
这篇文章介绍了基于Azure实现Java访问OpenAI,分享给大家做个参考,收藏Java资料网收获更多编程知识
之前使用Java代码直接访问OpenAI(参见:OpenAI注册以及Java代码调用_雨欲语的博客-CSDN博客),但那种方式需要VPN才能访问。现在可以基于微软的Azure访问OpenAI,不再需要VPN。官方文档:快速入门 - 开始通过 Azure OpenAI 服务使用 ChatGPT 和 GPT-4 - Azure OpenAI Service | Microsoft Learn。官方为Python和C#提供了SDK封装,Java没有,但可以直接发送HTTP请求调用其REST接口进行访问。
Azure申请:什么是 Azure OpenAI 服务? - Azure Cognitive Services | Microsoft Learn
首先,根据接口返回的结果封装一些Java类:
AzureAIChatResponse类:
/**
 * Body of a chat-completions reply from the Azure OpenAI service.
 *
 * <p>A plain mutable bean: the JSON reply is mapped onto these fields (by Gson or Hutool in the
 * client), and callers read them back through the getters.
 */
public class AzureAIChatResponse {
    /** Identifier the service assigned to this completion. */
    private String id;
    /** Object-type tag reported by the service. */
    private String object;
    /** Creation time exactly as sent by the service. NOTE(review): presumably a Unix timestamp; kept as text. */
    private String created;
    /** Name of the model that produced the reply. */
    private String model;
    /** Token accounting for the request/response pair. */
    private AzureAIUsage usage;
    /** Generated alternatives, in the order the service returned them. */
    private List<AzureAIChoice> choices;

    // ---- accessors -------------------------------------------------------

    public String getId() {
        return this.id;
    }

    public String getObject() {
        return this.object;
    }

    public String getCreated() {
        return this.created;
    }

    public String getModel() {
        return this.model;
    }

    public AzureAIUsage getUsage() {
        return this.usage;
    }

    public List<AzureAIChoice> getChoices() {
        return this.choices;
    }

    // ---- mutators --------------------------------------------------------

    public void setId(String value) {
        this.id = value;
    }

    public void setObject(String value) {
        this.object = value;
    }

    public void setCreated(String value) {
        this.created = value;
    }

    public void setModel(String value) {
        this.model = value;
    }

    public void setUsage(AzureAIUsage value) {
        this.usage = value;
    }

    public void setChoices(List<AzureAIChoice> value) {
        this.choices = value;
    }
}
AzureAIUsage类:
/**
 * Token accounting attached to a completion reply.
 *
 * <p>Example of the JSON this maps from:
 * <pre>
 * "prompt_tokens": 10,
 * "completion_tokens": 9,
 * "total_tokens": 19
 * </pre>
 * The {@code @SerializedName} annotations map the snake_case JSON keys onto the camelCase fields
 * for Gson.
 */
public class AzureAIUsage {
    /** Tokens consumed by the prompt. */
    @SerializedName("prompt_tokens")
    private int promptTokens;

    /** Tokens produced in the completion. */
    @SerializedName("completion_tokens")
    private int completionTokens;

    /** Prompt plus completion tokens, as reported by the service. */
    @SerializedName("total_tokens")
    private int totalTokens;

    // ---- accessors -------------------------------------------------------

    public int getPromptTokens() {
        return this.promptTokens;
    }

    public int getCompletionTokens() {
        return this.completionTokens;
    }

    public int getTotalTokens() {
        return this.totalTokens;
    }

    // ---- mutators --------------------------------------------------------

    public void setPromptTokens(int count) {
        this.promptTokens = count;
    }

    public void setCompletionTokens(int count) {
        this.completionTokens = count;
    }

    public void setTotalTokens(int count) {
        this.totalTokens = count;
    }
}
AzureAIChoice类:
/**
 * One generated alternative inside a chat-completions reply.
 *
 * <p>The original snippet declared the field as {@code Message}, a type that exists nowhere in this
 * code; the message type used throughout is {@link AzureAIMessage}. Accessors are added so callers
 * can actually read the generated message (e.g. {@code choice.getMessage().getContent()}).
 */
public class AzureAIChoice {
    /** The message the model produced for this choice. */
    private AzureAIMessage message;

    /**
     * Returns the generated message.
     *
     * @return the message, or {@code null} if the reply did not contain one
     */
    public AzureAIMessage getMessage() {
        return message;
    }

    /**
     * Sets the generated message (normally done by the JSON mapper).
     *
     * @param message the message for this choice
     */
    public void setMessage(AzureAIMessage message) {
        this.message = message;
    }
}
AzureAIMessage类:
/**
 * A chat message as it appears in a completion reply: the speaker's role and the text.
 *
 * <p>The original version had private fields and no accessors, so the reply text could not be read
 * by callers. Getters and setters are added.
 *
 * <p>NOTE(review): the article later defines a second class named {@code AzureAIMessage} (with the
 * role typed as {@code AzureAIRole}) for requests. Two top-level classes with the same name cannot
 * coexist in one package; confirm which definition the project actually uses.
 */
public class AzureAIMessage {
    /** Speaker role as sent by the service, e.g. "assistant". */
    private String role;
    /** Text of the message. */
    private String content;

    /**
     * Returns the speaker role.
     *
     * @return the role string, or {@code null} if absent
     */
    public String getRole() {
        return role;
    }

    /**
     * Sets the speaker role.
     *
     * @param role role string as used on the wire
     */
    public void setRole(String role) {
        this.role = role;
    }

    /**
     * Returns the message text.
     *
     * @return the content, or {@code null} if absent
     */
    public String getContent() {
        return content;
    }

    /**
     * Sets the message text.
     *
     * @param content message text
     */
    public void setContent(String content) {
        this.content = content;
    }
}
然后,根据请求参数封装请求类:
AzureAIChatRequest类:
/**
 * Request body for the Azure OpenAI chat-completions endpoint.
 *
 * <p>Optional settings are held in boxed fields so they can be left unset ({@code null}); with a
 * default {@code Gson} instance null fields are not serialized, so unset options are simply
 * omitted from the request.
 *
 * <p>Fix over the original: the getters for the boxed options used to return {@code int}, so
 * calling e.g. {@link #getMaxTokens()} on a request where it had not been set threw a
 * {@link NullPointerException} from auto-unboxing. They now return {@link Integer} (source
 * compatible for callers that assign to {@code int} when the value is set).
 */
public class AzureAIChatRequest {
    /** Conversation so far, in order. */
    private List<AzureAIMessage> messages;
    /** Sampling temperature; {@code null} leaves the service default. */
    private Double temperature;
    /** Number of alternatives to generate (sent as {@code n}). */
    @SerializedName("n")
    private Integer choices;
    /** Whether to request a streamed response. Always sent, since it is a primitive. */
    private boolean stream;
    /** Stop sequence; {@code null} means none. */
    private String stop;
    /** Upper bound on generated tokens; {@code null} leaves the service default. */
    @SerializedName("max_tokens")
    private Integer maxTokens;
    /** NOTE(review): the API accepts fractional penalties; this field can only hold whole numbers. */
    @SerializedName("presence_penalty")
    private Integer presencePenalty;
    /** NOTE(review): the API accepts fractional penalties; this field can only hold whole numbers. */
    @SerializedName("frequency_penalty")
    private Integer frequencyPenalty;
    /** Optional end-user identifier. */
    private String user;

    public List<AzureAIMessage> getMessages() {
        return messages;
    }

    public void setMessages(List<AzureAIMessage> messages) {
        this.messages = messages;
    }

    /**
     * Appends a message to the conversation, creating the list on first use.
     *
     * @param message the message to append
     */
    public void addMessage(AzureAIMessage message) {
        if (this.messages == null) {
            this.messages = new ArrayList<>();
        }
        this.messages.add(message);
    }

    public Double getTemperature() {
        return temperature;
    }

    public void setTemperature(Double temperature) {
        this.temperature = temperature;
    }

    /**
     * @return number of alternatives requested, or {@code null} if not set
     */
    public Integer getChoices() {
        return choices;
    }

    /**
     * @param choices number of alternatives to request; {@code null} clears it
     */
    public void setChoices(Integer choices) {
        this.choices = choices;
    }

    public boolean isStream() {
        return stream;
    }

    public void setStream(boolean stream) {
        this.stream = stream;
    }

    /**
     * Returns the stop sequence.
     *
     * @return the stop sequence, or {@code null} if none
     */
    public String getStop() {
        return stop;
    }

    /**
     * Same as {@link #getStop()}; kept for existing callers despite the boolean-style name.
     *
     * @return the stop sequence, or {@code null} if none
     */
    public String isStop() {
        return stop;
    }

    /**
     * Sets the stop sequence; the model stops generating when it produces this text.
     *
     * @param stop the stop sequence, or {@code null} for none
     */
    public void setStop(String stop) {
        this.stop = stop;
    }

    /**
     * Sets the stop sequence to the literal text {@code "true"} or {@code "false"}.
     *
     * @param stop flag whose text form becomes the stop sequence
     * @deprecated {@code stop} is a stop <em>sequence</em>, not a flag: this makes the model stop
     *     whenever it emits the word "true"/"false". Use {@link #setStop(String)} instead
     *     ({@code null} for no stop sequence).
     */
    @Deprecated
    public void setStop(boolean stop) {
        this.stop = stop ? "true" : "false";
    }

    /**
     * @return the token limit, or {@code null} if not set
     */
    public Integer getMaxTokens() {
        return maxTokens;
    }

    /**
     * @param maxTokens the token limit; {@code null} clears it
     */
    public void setMaxTokens(Integer maxTokens) {
        this.maxTokens = maxTokens;
    }

    /**
     * @return the presence penalty, or {@code null} if not set
     */
    public Integer getPresencePenalty() {
        return presencePenalty;
    }

    /**
     * @param presencePenalty the presence penalty; {@code null} clears it
     */
    public void setPresencePenalty(Integer presencePenalty) {
        this.presencePenalty = presencePenalty;
    }

    /**
     * @return the frequency penalty, or {@code null} if not set
     */
    public Integer getFrequencyPenalty() {
        return frequencyPenalty;
    }

    /**
     * @param frequencyPenalty the frequency penalty; {@code null} clears it
     */
    public void setFrequencyPenalty(Integer frequencyPenalty) {
        this.frequencyPenalty = frequencyPenalty;
    }

    public String getUser() {
        return user;
    }

    public void setUser(String user) {
        this.user = user;
    }
}
AzureAIMessage类:
/**
 * A single chat message sent to the service: who is speaking and what they said.
 */
public class AzureAIMessage {
    /** Speaker of this message. */
    private AzureAIRole role;
    /** Text of this message. */
    private String content;

    /** Creates an empty message; populate it through the setters. */
    public AzureAIMessage() {
    }

    /**
     * Creates a message with the given text and speaker.
     *
     * @param content message text
     * @param role    speaker of the message
     */
    public AzureAIMessage(String content, AzureAIRole role) {
        this.role = role;
        this.content = content;
    }

    public String getContent() {
        return this.content;
    }

    public void setContent(String text) {
        this.content = text;
    }

    public AzureAIRole getRole() {
        return this.role;
    }

    public void setRole(AzureAIRole speaker) {
        this.role = speaker;
    }
}
AzureAIRole类:
/**
 * Speaker roles understood by the chat endpoint.
 *
 * <p>{@code @SerializedName} makes Gson write the lowercase wire value instead of the constant
 * name; {@link #toString()} returns the same lowercase value.
 */
public enum AzureAIRole {
    @SerializedName("assistant")
    ASSISTANT("assistant"),
    @SerializedName("system")
    SYSTEM("system"),
    @SerializedName("user")
    USER("user");

    /** Lowercase name of the role as used on the wire. */
    private final String wireName;

    AzureAIRole(String wireName) {
        this.wireName = wireName;
    }

    /**
     * @return the lowercase wire value of this role
     */
    @Override
    public String toString() {
        return wireName;
    }
}
客户端访问类:
import cn.hutool.core.date.BetweenFormatter;
import cn.hutool.core.date.DateUnit;
import cn.hutool.core.date.DateUtil;
import cn.hutool.http.HttpRequest;
import cn.hutool.http.HttpResponse;
import cn.hutool.json.JSONUtil;
import com.google.gson.Gson;
import lombok.extern.slf4j.Slf4j;
import org.asynchttpclient.*;
import java.io.Closeable;
import java.io.IOException;
import java.util.Date;
import java.util.concurrent.Future;
/**
 * Minimal client for one Azure OpenAI deployment, calling its REST endpoints directly.
 *
 * <p>Requests go to {@code <endpoint>/openai/deployments/<deploymentName>/<path>?api-version=<apiVersion>}
 * and authenticate with the {@code api-key} header. Most calls use the async-http-client instance
 * owned by this object; {@link #sendMyChatRequest} uses Hutool's {@code HttpRequest} instead.
 *
 * <p>Fixes over the original:
 * <ul>
 *   <li>removed a stray uncommented line referencing undefined {@code httpHeaders}/{@code MediaType}
 *       (compile error) together with the dead RestTemplate code around it;</li>
 *   <li>errors now report the real status, URL and api-version (the log line used to hard-code
 *       {@code 2023-03-15-preview});</li>
 *   <li>the Hutool path checks the HTTP status and no longer prints raw bodies to stderr;</li>
 *   <li>{@link #close()} logs close failures instead of swallowing them;</li>
 *   <li>all requests authenticate with {@code api-key} (one overload used {@code Bearer});</li>
 *   <li>a trailing "/" on the endpoint no longer produces a "//openai" URL.</li>
 * </ul>
 */
@Slf4j
public class AzureAIClient implements Closeable {
    /** Media type sent in the Accept and Content-Type headers. */
    private static final String JSON = "application/json; charset=UTF-8";
    /** Read timeout, in milliseconds, for the Hutool request path. */
    private static final int READ_TIMEOUT_MS = 30000;
    private static final Version version = new Version();

    /** Whether {@link #close()} also closes {@link #client}; true because this class creates it. */
    private final boolean closeClient;
    private final AsyncHttpClient client;
    private final String deploymentName;
    /** Deployment base URL, ending in "/", to which endpoint paths are appended. */
    private final String url;
    /** Azure OpenAI API key, sent in the {@code api-key} header. */
    private final String token;
    private final String apiVersion;
    private boolean closed = false;
    final Gson gson = new Gson();

    /**
     * Creates a client bound to one deployment.
     *
     * @param url            resource endpoint, e.g. {@code https://xxx.openai.azure.com}
     * @param apiKey         API key of the Azure OpenAI resource
     * @param deploymentName name of the model deployment
     * @param apiVersion     value for the {@code api-version} query parameter
     * @throws IllegalArgumentException if {@code url} is {@code null}
     * @throws Exception                declared for backward compatibility
     */
    public AzureAIClient(String url, String apiKey, String deploymentName, String apiVersion) throws Exception {
        if (url == null) {
            // Checked before the HTTP client is created so a bad argument cannot leak it.
            throw new IllegalArgumentException("url must not be null");
        }
        this.client = new DefaultAsyncHttpClient();
        this.url = stripTrailingSlashes(url) + "/openai/deployments/" + deploymentName + "/";
        this.token = apiKey;
        this.deploymentName = deploymentName;
        this.apiVersion = apiVersion;
        this.closeClient = true;
    }

    /**
     * @return {@code true} once {@link #close()} ran or the underlying client was closed
     */
    public boolean isClosed() {
        return closed || client.isClosed();
    }

    /** Closes the underlying HTTP client (best effort) and marks this client closed. */
    @Override
    public void close() {
        if (closeClient && !client.isClosed()) {
            try {
                client.close();
            } catch (IOException ex) {
                // Closing is best-effort, but the failure should be visible rather than swallowed.
                log.warn("Failed to close AsyncHttpClient", ex);
            }
        }
        closed = true;
    }

    public static String getVersion() {
        return version.getBuildNumber();
    }

    public static String getBuildName() {
        return version.getBuildName();
    }

    /**
     * Calls the (non-chat) completions endpoint.
     *
     * @param completion request body
     * @return the parsed result
     * @throws IOException if the service answers with a status other than 200
     * @throws Exception   if the request cannot be executed or is interrupted
     */
    public AzureAICompletionsResult getCompletion(AzureAICompletionRequest completion) throws Exception {
        Response r = post("completions?api-version=" + apiVersion, gson.toJson(completion));
        if (r.getStatusCode() != 200) {
            throw failure("Could not get completion result", r.getStatusCode(), r.getResponseBody());
        }
        return gson.fromJson(r.getResponseBody(), AzureAICompletionsResult.class);
    }

    /**
     * Calls the embeddings endpoint.
     *
     * @param embedding request body
     * @return the parsed response
     * @throws IOException if the service answers with a status other than 200
     * @throws Exception   if the request cannot be executed or is interrupted
     */
    public AzureAICreateEmbedingResponse createEmbedding(AzureAIEmbedding embedding) throws Exception {
        Response r = post("embeddings?api-version=" + apiVersion, gson.toJson(embedding));
        if (r.getStatusCode() != 200) {
            throw failure("Could not create embedding", r.getStatusCode(), r.getResponseBody());
        }
        return JSONUtil.toBean(r.getResponseBody(), AzureAICreateEmbedingResponse.class);
    }

    /**
     * Sends a chat request through Hutool's synchronous HTTP client and logs request/parse timings.
     *
     * @param chatRequest request body
     * @return the parsed response
     * @throws IOException if the service answers with a non-2xx status
     * @throws Exception   for any other request failure
     */
    public AzureAIChatResponse sendMyChatRequest(AzureAIChatRequest chatRequest) throws Exception {
        long requestStart = System.currentTimeMillis();
        String body = buildMyRequest("POST", "chat/completions?api-version=" + apiVersion, gson.toJson(chatRequest));
        log.info("请求数据耗时(毫秒):{}", elapsedSince(requestStart));
        // Raw bodies used to be printed to stderr; keep them available at debug level only.
        log.debug("Raw chat response: {}", body);

        long parseStart = System.currentTimeMillis();
        AzureAIChatResponse azureAIChatResponse = gson.fromJson(body, AzureAIChatResponse.class);
        log.info("格式化数据耗时(毫秒):{}", elapsedSince(parseStart));
        return azureAIChatResponse;
    }

    /**
     * Executes a POST with Hutool and returns the response body.
     *
     * @param type        HTTP method; NOTE: ignored, a POST is always sent (as in the original)
     * @param subUrl      path and query relative to the deployment URL
     * @param requestBody JSON body
     * @return the response body
     * @throws IOException if the status is not 2xx
     */
    private String buildMyRequest(String type, String subUrl, String requestBody) throws IOException {
        try (HttpResponse response = HttpRequest.post(this.url + subUrl)
                .header("Accept", JSON)
                .header("Content-Type", "application/json")
                .header("api-key", this.token)
                .setReadTimeout(READ_TIMEOUT_MS)
                .body(requestBody)
                .execute()) {
            String body = response.body();
            if (!response.isOk()) {
                throw failure("Could not create chat request", response.getStatus(), body);
            }
            return body;
        }
    }

    /**
     * Sends a chat request through the async client and logs request/parse timings.
     *
     * @param chatRequest request body
     * @return the parsed response, or {@code null} if the service answered with a non-200 status
     *     (the status, URL and error body are logged)
     * @throws Exception if the request cannot be executed or is interrupted
     */
    public AzureAIChatResponse sendChatRequest(AzureAIChatRequest chatRequest) throws Exception {
        String path = "chat/completions?api-version=" + apiVersion;
        long requestStart = System.currentTimeMillis();
        Response r = post(path, gson.toJson(chatRequest));
        log.info("请求数据耗时(毫秒):{}", elapsedSince(requestStart));

        if (r.getStatusCode() != 200) {
            log.warn("Could not create chat request - server response was {} to url: {}{} - body: {}",
                    r.getStatusCode(), url, path, r.getResponseBody());
            return null;
        }

        long parseStart = System.currentTimeMillis();
        AzureAIChatResponse azureAIChatResponse = JSONUtil.toBean(r.getResponseBody(), AzureAIChatResponse.class);
        log.info("格式化数据耗时(毫秒):{}", elapsedSince(parseStart));
        return azureAIChatResponse;
    }

    /** Sends a POST through the async client and blocks until the response arrives. */
    private Response post(String subUrl, String requestBody) throws Exception {
        Future<Response> pending = client.executeRequest(buildRequest("POST", subUrl, requestBody));
        return pending.get();
    }

    /** Builds a body-less request; authenticates with {@code api-key} like every other call. */
    private Request buildRequest(String type, String subUrl) {
        return baseRequest(type, subUrl).build();
    }

    /** Builds a request carrying {@code requestBody}. */
    private Request buildRequest(String type, String subUrl, String requestBody) {
        return baseRequest(type, subUrl).setBody(requestBody).build();
    }

    /** Common URL and headers shared by all async-client requests. */
    private RequestBuilder baseRequest(String type, String subUrl) {
        return new RequestBuilder(type)
                .setUrl(this.url + subUrl)
                .addHeader("Accept", JSON)
                .addHeader("Content-Type", JSON)
                .addHeader("api-key", this.token);
    }

    /** Formats the time elapsed since {@code startMillis} with Hutool's millisecond formatter. */
    private static String elapsedSince(long startMillis) {
        return DateUtil.formatBetween(System.currentTimeMillis() - startMillis, BetweenFormatter.Level.MILLISECOND);
    }

    /** Builds the exception reported for a non-success HTTP status. */
    private static IOException failure(String action, int status, String body) {
        return new IOException(action + " - server response was " + status + ": " + body);
    }

    /** Removes any trailing '/' characters so path segments can be appended safely. */
    private static String stripTrailingSlashes(String value) {
        int end = value.length();
        while (end > 0 && value.charAt(end - 1) == '/') {
            end--;
        }
        return value.substring(0, end);
    }
}
调用测试:
public static void main(String[] args) throws Exception {
    // Example generation settings. The original snippet referenced maxTokens/temperature
    // without ever declaring them.
    int maxTokens = 800;
    double temperature = 0.7;

    // Assemble the request messages: a system prompt followed by the user's question.
    List<AzureAIMessage> azureAiMessageList = new ArrayList<>();
    AzureAIChatRequest azureAiChatRequest = new AzureAIChatRequest();

    AzureAIMessage azureAIMessage0 = new AzureAIMessage();
    azureAIMessage0.setRole(AzureAIRole.SYSTEM);
    azureAIMessage0.setContent("你是一个AI机器人,请根据提问进行回答");
    azureAiMessageList.add(azureAIMessage0);

    AzureAIMessage azureAIMessage1 = new AzureAIMessage();
    azureAIMessage1.setRole(AzureAIRole.USER);
    azureAIMessage1.setContent("请解释一下java的gc");
    azureAiMessageList.add(azureAIMessage1);

    azureAiChatRequest.setMessages(azureAiMessageList);
    azureAiChatRequest.setMaxTokens(maxTokens);
    azureAiChatRequest.setTemperature(temperature);
    // Whether to stream the response:
    // azureAiChatRequest.setStream(true);
    azureAiChatRequest.setPresencePenalty(0);
    azureAiChatRequest.setFrequencyPenalty(0);
    azureAiChatRequest.setStop(null);

    // try-with-resources closes the underlying HTTP client; the original never closed it.
    try (AzureAIClient azureAIClient = new AzureAIClient("申请的azure地址", "azure的apikey",
            "模型(gpt-35-turbo)", "api版本:(2023-03-15-preview)")) {
        // The original passed an undeclared variable (azureAIChatRequest) here.
        AzureAIChatResponse azureAIChatResponse = azureAIClient.sendChatRequest(azureAiChatRequest);
        if (azureAIChatResponse == null) {
            // sendChatRequest returns null on a non-200 status and logs the details.
            System.err.println("Chat request failed; see the log for details.");
        } else {
            System.out.println("Response id: " + azureAIChatResponse.getId());
        }
    }
}
maven依赖:
<dependencies>
    <!-- Async HTTP client used by AzureAIClient for most requests -->
    <dependency>
        <groupId>org.asynchttpclient</groupId>
        <artifactId>async-http-client</artifactId>
        <version>2.12.3</version>
        <type>jar</type>
    </dependency>
    <!-- Gson: JSON (de)serialization and the @SerializedName annotations -->
    <dependency>
        <groupId>com.google.code.gson</groupId>
        <artifactId>gson</artifactId>
        <version>2.10.1</version>
    </dependency>
    <!-- Hutool: HttpRequest, JSONUtil and DateUtil are imported by AzureAIClient but were missing here -->
    <dependency>
        <groupId>cn.hutool</groupId>
        <artifactId>hutool-all</artifactId>
        <version>5.8.20</version>
    </dependency>
    <!-- Lombok: generates the logger for @Slf4j; needed at compile time only -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <version>1.18.28</version>
        <scope>provided</scope>
    </dependency>
</dependencies>
文章来源:https://blog.csdn.net/qq_41061437/article/details/130927618
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:jacktools123@163.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:jacktools123@163.com进行投诉反馈,一经查实,立即删除!
标签: