参考文档:https://help.aliyun.com/zh/dashscope/developer-reference/quick-start

一、创建API-KEY

控制台地址:https://dashscope.console.aliyun.com/apiKey

二、使用java SDK接入

全量输出:根据用户的 prompt,一次性将完整结果返回给用户。

增量输出:根据用户的 prompt,将结果一个字一个字地推送给用户,类似打字机的效果。

由于全量输出比较简单,这里重点介绍增量输出。增量输出用到了 SSE(Server-Sent Events,服务端推送事件)技术,不熟悉的读者可自行查阅相关资料,这里不做详细介绍。

2.1、引入依赖

新建一个 Spring Boot 项目,并引入通义千问(DashScope)的 Java SDK:

<!-- DashScope Java SDK: supplies the com.alibaba.dashscope.* classes (Conversation,
     ImageSynthesis, ...) imported by QianWenUtil. -->
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>dashscope-sdk-java</artifactId>
    <version>2.9.0</version>
</dependency>

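增量输出用到的 SseEmitter 已包含在 spring-webmvc 包中,无需额外添加。此外,下文示例代码还用到了 Lombok(@Slf4j)、Guava(Throwables)、Hutool(IoUtil)和 COLA(SingleResponse),需要一并引入对应依赖。QianWenUtil 中静态导入的 QwConstants.API_KEY 是存放第一步所创建 API-KEY 的常量类,文中未给出,需要自行定义。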

2.2、后端实现

2.2.1、SseController

package com.example.tyqw;

import com.alibaba.cola.dto.SingleResponse;
import com.google.common.base.Throwables;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import javax.annotation.Resource;
import java.util.concurrent.CompletableFuture;

@Slf4j
@RestController
@RequestMapping("sse")
public class SseController {
    
    @Resource
    private SseService sseService;

    @GetMapping(value = "connect/{clientId}")
    public SseEmitter connect(@PathVariable("clientId") String clientId) {
        final SseEmitter emitter = sseService.getConn(clientId);
        return emitter;
    }

    @GetMapping(value = "send/{clientId}")
    public SingleResponse send(@PathVariable("clientId") String clientId, @RequestParam String message) {
        CompletableFuture.runAsync(() -> {
            try {
                sseService.send(clientId, message);
            } catch (Exception e) {
                log.error("推送数据异常-{}", Throwables.getStackTraceAsString(e));
            }
        });
        return SingleResponse.buildSuccess();
    }

    @GetMapping("close/{clientId}")
    public SingleResponse close(@PathVariable("clientId") String clientId) {
        sseService.closeConn(clientId);
        log.info("===clientId-{}-连接已关闭",clientId);
        return SingleResponse.of("连接已关闭");
    }

    @GetMapping("img")
    public String img(@RequestParam String message) {
        log.info("====>生成图片-prompt-处理中-{}",message);
        return sseService.img(message);
    }

}

2.2.2、SseService以及实现类

package com.example.tyqw;

import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

/**
 * Manages per-client SSE connections and pushes model output to them.
 */
public interface SseService {

    /**
     * Returns the SSE emitter for the given client.
     *
     * @param clientId id chosen by the browser to identify its session
     * @return the emitter to return from the connect endpoint
     */
    SseEmitter getConn(String clientId);

    /**
     * Closes the SSE connection of the given client, if there is one.
     *
     * @param clientId id of the client to disconnect
     */
    void closeConn(String clientId);

    /**
     * Sends {@code message} to the model and pushes the answer to the client's stream.
     *
     * @param clientId id of the client whose stream receives the answer
     * @param message  the user's prompt
     */
    void send(String clientId, String message);

    /**
     * Generates an image for the prompt.
     *
     * @param message the image prompt
     * @return the image as a Base64 string; the implementation in this demo returns
     *         {@code null} when generation fails
     */
    String img(String message);
}

package com.example.tyqw;

import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
 * {@link SseService} that keeps one {@link SseEmitter} per client id in memory and streams
 * Qwen answers to it through {@link QianWenUtil}.
 */
@Slf4j
@Service
public class SseServiceImpl implements SseService {

    /**
     * Emitter timeout in milliseconds (1 hour). Because it is passed to the SseEmitter
     * constructor, it takes precedence over spring.mvc.async.request-timeout for these requests.
     */
    private static final long SSE_TIMEOUT_MILLIS = 3600 * 1000L;

    /** Open emitters keyed by client id; an entry is evicted once its emitter finishes. */
    private static final Map<String, SseEmitter> SSE_CACHE = new ConcurrentHashMap<>();

    /**
     * Returns the cached emitter for {@code clientId}, creating and caching a new one if none
     * exists. computeIfAbsent makes the check-and-create atomic, so two concurrent connects for
     * the same client cannot register two different emitters.
     *
     * @param clientId id chosen by the browser to identify its session
     * @return the emitter for this client
     */
    @Override
    public SseEmitter getConn(String clientId) {
        return SSE_CACHE.computeIfAbsent(clientId, this::createEmitter);
    }

    /** Creates an emitter whose lifecycle callbacks evict it from the cache. */
    private SseEmitter createEmitter(String clientId) {
        final SseEmitter emitter = new SseEmitter(SSE_TIMEOUT_MILLIS);
        // Called when the async request times out.
        emitter.onTimeout(() -> {
            log.info("连接已超时,正准备关闭,clientId = {}", clientId);
            evict(clientId, emitter);
        });
        // Called once the emitter is done (e.g. after emitter.complete()). Evicting here keeps
        // finished emitters from piling up in the cache and from being handed out again by getConn.
        emitter.onCompletion(() -> {
            log.info("处理已完成,clientId = {}", clientId);
            evict(clientId, emitter);
        });
        // Called on errors, e.g. after emitter.completeWithError().
        emitter.onError(throwable -> {
            // String.valueOf keeps the original "toString" output instead of SLF4J's stack-trace handling.
            log.info("连接已异常,正准备关闭,clientId = {}==>{}", clientId, String.valueOf(throwable));
            evict(clientId, emitter);
        });
        return emitter;
    }

    /**
     * Removes the cache entry only if it still maps to {@code emitter}, so a newer emitter
     * registered for the same client id is left untouched.
     */
    private static void evict(String clientId, SseEmitter emitter) {
        SSE_CACHE.remove(clientId, emitter);
    }

    /**
     * Streams the model's answer to {@code message} to the client's emitter, in the style of
     * ChatGPT's typewriter output. Blocks until the stream ends.
     *
     * @param clientId id of the client whose emitter receives the answer
     * @param message  the user's prompt
     */
    @Override
    public void send(String clientId, String message) {
        final SseEmitter emitter = SSE_CACHE.get(clientId);
        if (Objects.nonNull(emitter)) {
            QianWenUtil.getFlowAnswer(emitter, message);
        } else {
            log.error("请刷新页面后重试");
        }
    }

    /**
     * Ends the client's stream. The entry is removed eagerly so a reconnect with the same id
     * gets a fresh emitter even if the completion callback never runs (e.g. the emitter was
     * never bound to a request).
     *
     * @param clientId id of the client to disconnect
     */
    @Override
    public void closeConn(String clientId) {
        final SseEmitter sseEmitter = SSE_CACHE.remove(clientId);
        if (sseEmitter != null) {
            sseEmitter.complete();
        }
    }

    /**
     * Generates an image for the prompt.
     *
     * @param message the image prompt
     * @return the image as Base64, or {@code null} if generation failed
     */
    @Override
    public String img(String message) {
        return QianWenUtil.getImageBase64(message);
    }
}

2.2.3、QianWenUtil

package com.example.tyqw;

import cn.hutool.core.io.IORuntimeException;
import cn.hutool.core.io.IoUtil;
import com.alibaba.dashscope.aigc.conversation.Conversation;
import com.alibaba.dashscope.aigc.conversation.ConversationParam;
import com.alibaba.dashscope.aigc.conversation.ConversationResult;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
import com.alibaba.dashscope.exception.ApiException;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.google.common.base.Throwables;
import io.reactivex.Flowable;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.Base64;

import static com.example.tyqw.QwConstants.API_KEY;

/**
 * @author DUCHONG
 * @date 2023-09-14 17:29
 **/
@Slf4j
public class QianWenUtil {


    /**
     * 全量输出
     * @param prompt
     * @return
     */
    public static String getAnswer(String prompt){
        try {
            Conversation conversation = new Conversation();

            ConversationParam param = ConversationParam
                    .builder()
                    .model(Conversation.Models.QWEN_TURBO)
                    .apiKey(API_KEY)
                    .prompt(prompt)
                    .build();
            ConversationResult result = conversation.call(param);
            return result.getOutput().getText();
        } catch (ApiException | NoApiKeyException | InputRequiredException e) {
            System.out.println(e.getMessage());
        }
        return null;
    }

    /**
     * 增量输出-打字机效果
     * @param sseEmitter
     * @param prompt
     */
    public static void getFlowAnswer(SseEmitter sseEmitter,String prompt){
        try {
            Conversation conversation = new Conversation();
            ConversationParam param = ConversationParam
                    .builder()
                    .model(Conversation.Models.QWEN_MAX)
                    .apiKey(API_KEY)
                    .prompt(prompt)
                    .incrementalOutput(Boolean.TRUE)
                    .build();
            Flowable<ConversationResult> result = conversation.streamCall(param);
            result.blockingForEach(message -> {
                //log.info("message====>{}", JSON.toJSONString(message));
                String outPut = message.getOutput().getText();
                if (("stop").equals(message.getOutput().getFinishReason())) {
                    sseEmitter.send("DONE");
                }
                else {
                    outPut = outPut.replaceAll("\\n", "</br>");
                    sseEmitter.send(outPut);
                }
            });
        } catch (ApiException | NoApiKeyException | InputRequiredException e) {
            log.error("===>发送流式结果ERROR-{}", Throwables.getStackTraceAsString(e));
        }
        finally {
            //结束推流
            //sseEmitter.complete();
        }
    }

    /**
     * 生成图片的base64编码
     * @param prompt
     * @return
     */
    public static String getImageBase64(String prompt){
        try {
            ImageSynthesis is = new ImageSynthesis();
            ImageSynthesisParam param =
                    ImageSynthesisParam.builder()
                            .model(ImageSynthesis.Models.WANX_V1)
                            .n(1)
                            .size("1024*1024")
                            .apiKey(API_KEY)
                            .prompt(prompt)
                            .build();

            ImageSynthesisResult result = is.call(param);
            String url = result.getOutput().getResults().get(0).get("url");
            byte[] urlByteArray = IoUtil.readBytes(new URL(url).openStream());
            String base64Encoded = Base64.getEncoder().encodeToString(urlByteArray);
            log.info("===>生成图片完成-RUL-{}",url);
            return base64Encoded;
        } catch (NoApiKeyException | IOException e) {
            log.error("===>生成图片失败-ERROR-{}", Throwables.getStackTraceAsString(e));
        }
        return null;
    }
}

2.3、前端页面

2.3.1、chat.html

<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <title>通义千问</title>
</head>
<body>
    <h3>请输入问题:</h3>
    <!-- Keep no whitespace between the tags: it would become the textarea's initial value
         and be sent along with every prompt. -->
    <textarea id="prompt"></textarea>
    <button onclick="sendPrompt()">发送</button>
    <h3><pre id="message"></pre></h3>

<script>
    // Demo client for the /sse endpoints (connect, send, close) of the Spring Boot backend.
    const BASE_URL = 'http://localhost:8080/sse/';

    // The SSE connection; stays null when the browser has no EventSource support.
    let source = null;
    // Use the current timestamp as a stand-in for a logged-in user id.
    const userId = new Date().getTime();

    /**
     * Sends the textarea content as a prompt. The answer does not come back on this request;
     * it arrives chunk by chunk on the SSE connection opened below.
     * (Named sendPrompt so it does not shadow the built-in window.prompt.)
     */
    function sendPrompt() {
        const message = document.getElementById('prompt').value.trim();
        if (!message) {
            return;
        }
        const httpRequest = new XMLHttpRequest();
        // encodeURIComponent: otherwise '&', '#', '+', '%' in the prompt corrupt the query string.
        httpRequest.open('GET', BASE_URL + 'send/' + userId + '?message=' + encodeURIComponent(message), true);
        httpRequest.send();
    }

    if (!!window.EventSource) {

        // Open the connection.
        source = new EventSource(BASE_URL + 'connect/' + userId);

        // Fired once the connection is established (alternative: source.onopen = function (event) {}).
        source.addEventListener('open', function (e) {
            console.log("连接SSE成功");
        }, false);

        // Fired for each message from the server (alternative: source.onmessage = function (event) {}).
        source.addEventListener('message', function (e) {
            setMessageInnerHTML(e.data);
        });

        // Fired on communication errors such as a dropped connection
        // (alternative: source.onerror = function (event) {}).
        source.addEventListener('error', function (e) {
            // The error event has no readyState property; read it from the EventSource itself.
            if (source.readyState === EventSource.CLOSED) {
                console.log("连接关闭");
            } else {
                console.log(e);
            }
        }, false);

    } else {
        setMessageInnerHTML("你的浏览器不支持SSE");
    }

    // When the window closes, actively close the SSE connection. If the server keeps emitters
    // alive indefinitely, this is what cleans up its state after the browser goes away.
    window.onbeforeunload = function () {
        closeSse();
    };

    // Close the SSE connection locally and tell the server to release it.
    function closeSse() {
        // source is null when EventSource is unsupported.
        if (source) {
            source.close();
        }
        const httpRequest = new XMLHttpRequest();
        httpRequest.open('GET', BASE_URL + 'close/' + userId, true);
        httpRequest.send();
        console.log("close");
    }

    // Append received content to the page. NOTE: it is inserted as HTML (the server encodes
    // line breaks as <br> tags), so any markup in the received text is interpreted by the browser.
    function setMessageInnerHTML(innerHTML) {
        document.getElementById('message').innerHTML += innerHTML;
    }

</script>
</body>
</html>

2.4、结果演示

声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。