Commit 930af0d7 authored by anShuang's avatar anShuang
Browse files

commit

No related merge requests found
Showing with 0 additions and 356 deletions
+0 -356
package diffusionAiService
import (
"aiapi/models/req"
"aiapi/repositories"
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/sirupsen/logrus"
"io"
"mime/multipart"
"net/http"
"os"
"strconv"
)
var DiffusionAiCha chan PubStruct
type PubStruct struct {
ApiName string `json:"apiName"`
InStruct interface{}
requestId int64 `json:"requestId"`
}
type DiffusionRes struct {
RequestId int64 `structs:"request_id" json:"requestId"`
}
type ImagesStruct struct {
Images []string `json:"images"`
}
// 获取空闲的AI算力机器,如果无空闲返回""
func queryComputePower() map[string]string {
dataReq := make(map[string]interface{})
computePower := make(map[string]string)
dataReq["status"] = int64(1)
var queryAiComputePower = new(repositories.TbAiComputePower)
dataRes := queryAiComputePower.Query(dataReq)
if len(dataRes) > 0 {
computePower["baseUrl"] = dataRes[0].BaseUrl
computePower["id"] = strconv.FormatInt(dataRes[0].Id, 10)
}
return computePower
}
// 更新算力状态
func updateComputePower(updateMapReq map[string]int64) (count int64) {
var updateComputerPowerReq = new(repositories.TbAiComputePower)
updateComputerPowerReq.Id, _ = updateMapReq["id"]
updateComputerPowerReq.Status = updateMapReq["status"]
updateComputerPowerRes := updateComputerPowerReq.Update()
return updateComputerPowerRes
}
// 获取userAiRequest请求数据
func queryUserAiRequest(requestId int64) (data repositories.TbUserAIRequest) {
var queryUserAIRequestReq = new(repositories.TbUserAIRequest)
queryUserAIRequestRes := queryUserAIRequestReq.GetById(strconv.FormatInt(requestId, 10))
return queryUserAIRequestRes
}
// 生成userAiRequest请求数据
func addUserAiRequest(addUserAiReq req.UserAIRequestReq) (id int64) {
var addUserAIRequestReq = new(repositories.TbUserAIRequest)
addUserAIRequestReq.UserId = addUserAiReq.UserId
addUserAIRequestReq.Status = addUserAiReq.Status
addUserAIRequestRes, err := addUserAIRequestReq.Add()
logrus.Info("request data = %v\n", addUserAIRequestReq)
if err != nil {
logrus.Info("AddUserAIRequest error = %v\n", err)
return
}
id = addUserAIRequestRes.Id
return
}
// 更新userAiRequest
func updateUserAiRequest(updateUserAiReq map[string]int64) (count int64) {
var updateUserAiRequestReq = new(repositories.TbAiComputePower)
updateUserAiRequestReq.Id, _ = updateUserAiReq["id"]
updateUserAiRequestReq.Status = updateUserAiReq["status"]
updateComputerPowerRes := updateUserAiRequestReq.Update()
return updateComputerPowerRes
}
// 管道
func DiffusionAiChanel() {
for {
select {
case chanReq := <-DiffusionAiCha:
go func(gi PubStruct) {
if chanReq.ApiName == "genImage" {
byteData, err := json.Marshal(gi.InStruct)
if err != nil {
logrus.Info(err.Error())
}
var inStrc GenImageReq
err = json.Unmarshal(byteData, &inStrc)
if err != nil {
logrus.Info(err.Error())
}
GoGenImageLogic(inStrc, gi.requestId)
}
}(chanReq)
}
}
}
// base64上传至服务器
func ImgUpload(fileName string, imgStr string) error {
file, err := os.OpenFile(fileName, os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return err
logrus.Info(err.Error())
}
w := bufio.NewWriter(file)
_, err3 := w.WriteString(string(imgStr))
if err3 != nil {
return err3
logrus.Info(err3.Error())
}
w.Flush()
defer file.Close()
return nil
}
func PostFile(filename string, targetUrl string) (*http.Response, error) {
bodyBuf := bytes.NewBufferString("")
bodyWriter := multipart.NewWriter(bodyBuf)
// use the body_writer to write the Part headers to the buffer
_, err := bodyWriter.CreateFormFile("userfile", filename)
if err != nil {
fmt.Println("error writing to buffer")
return nil, err
}
// the file data will be the second part of the body
fh, err := os.Open(filename)
if err != nil {
fmt.Println("error opening file")
return nil, err
}
// need to know the boundary to properly close the part myself.
boundary := bodyWriter.Boundary()
//close_string := fmt.Sprintf("\r\n--%s--\r\n", boundary)
close_buf := bytes.NewBufferString(fmt.Sprintf("\r\n--%s--\r\n", boundary))
// use multi-reader to defer the reading of the file data until
// writing to the socket buffer.
requestReader := io.MultiReader(bodyBuf, fh, close_buf)
fi, err := fh.Stat()
if err != nil {
fmt.Printf("Error Stating file: %s", filename)
return nil, err
}
req, err := http.NewRequest("POST", targetUrl, requestReader)
if err != nil {
return nil, err
}
// Set headers for multipart, and Content Length
req.Header.Add("Content-Type", "multipart/form-data; boundary="+boundary)
req.ContentLength = fi.Size() + int64(bodyBuf.Len()) + int64(close_buf.Len())
return http.DefaultClient.Do(req)
}
package diffusionAiService
import (
"aiapi/models/req"
"aiapi/repositories"
"aiapi/utils"
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/kataras/iris/v12"
"github.com/kataras/iris/v12/mvc"
"github.com/sirupsen/logrus"
"io"
"net/http"
"strconv"
"time"
)
type GenImageReq struct {
//用户Id
UserId int64 `structs:"user_id" json:"userId"`
// 提示词(必要)
Prompt string `structs:"prompt" json:"prompt"`
//状态
Status int64 `structs:"status" json:"status"`
// 种子,控制随机性 默认-1
seed int64 `structs:"seed" default:"-1"`
//一次生成几张图片 默认 1
batchSize int64 `structs:"batch_size"`
//提词器相关度 默认7
cfgScale int64 `structs:"cfg_scale"`
//负向提词器 默认 “”
negativePrompt string `structs:"negative_prompt"`
//覆盖默认参数的配置 修改使用的模型{"sd_model_checkpoint": "768-v-ema.ckpt [a2a802b2]"}
overrideSettings struct {
sdModelCheckpoint string `structs:"sd_model_checkpoint"`
}
//图片宽度 512
width int64 `structs:"width"`
//图片高度 512
height int64 `structs:"height"`
}
// GenImageLogic
// @Summary 根据文字生成图片,参数与python一致,返回为一个requestId
// @Description
// @Tags diffusionAiServiceService
// @Accept application/json
// @Produce application/json
// @Param prompt body GenImageReq true "json入参"
// @Success 200
// @Failure 500 "{"message":"fail","code":500,"data":null}"
// @Failure 520 "{"message":"station exists","code":520,"data":null}"
// @Router /diffusionAiServiceService/GenImage [post]
func GenImageLogic(ctx iris.Context) (mr mvc.Response) {
logrus.Printf(ctx.String())
var genImage GenImageReq
body, _ := ctx.GetBody()
err := json.Unmarshal(body, &genImage)
if err != nil {
utils.FailWithMsg(&mr, err)
logrus.Error(err)
return
}
var addUserAiReq req.UserAIRequestReq
addUserAiReq.UserId = genImage.UserId
addUserAiReq.Status = genImage.Status
// add userAiRequest
var res DiffusionRes
res.RequestId = addUserAiRequest(addUserAiReq)
utils.SuccessWithMsg(&mr, "success", res)
var diffusionAiCha PubStruct
diffusionAiCha.ApiName = "genImage"
diffusionAiCha.InStruct = genImage
diffusionAiCha.requestId = res.RequestId
//go to chanel
DiffusionAiCha <- diffusionAiCha
return
}
func GoGenImageLogic(genImage GenImageReq, requestId int64) {
computePower := make(map[string]string)
var userAiReq = new(repositories.TbUserAIRequest)
userAiReq.Id = requestId
// 如果没有空闲的算力,进程一致处于等待中 获取等待时间
//estimateSec等待时间计时
estimateSecStart := time.Now()
for {
computePower = queryComputePower()
if len(computePower) > 0 {
break
}
}
//获取用户请求最新状态,如果状态为已取消,则不调用python服务
userAiRes := queryUserAiRequest(requestId)
if userAiRes.Status == 4 {
return
} else {
//更新算力状态 update
computePowerMap := make(map[string]int64)
id, err := strconv.ParseInt(computePower["id"], 10, 64)
if err != nil {
logrus.Error(err)
}
computePowerMap["id"] = id
computePowerMap["status"] = 2
updateComputePower(computePowerMap)
//更新等待时间以及状态
estimateSec := time.Since(estimateSecStart)
userAiReq.ProcessPowerId, _ = strconv.ParseInt(computePower["id"], 10, 64)
userAiReq.EstimateSec = int64(estimateSec)
userAiReq.Status = 2
_ = userAiReq.Update()
fmt.Println(estimateSec)
//调用python程序
var url = "http://" + computePower["baseUrl"] + "/sdapi/v1/txt2img"
genImageReq := make(map[string]string)
genImageReq["prompt"] = genImage.Prompt
bytesData, _ := json.Marshal(genImageReq)
//执行时间计时processSec 更新执行时间以及状态
processSecStart := time.Now()
res, err := http.Post(url, "application/json", bytes.NewBuffer(bytesData))
processSec := time.Since(processSecStart)
userAiReq.ProcessSec = int64(processSec)
userAiReq.Status = 3
userAiReq.Update()
fmt.Println(processSec)
if err != nil {
fmt.Println("Fatal error ", err.Error())
}
//更新算力状态 update
computePowerMap["status"] = 1
updateComputePower(computePowerMap)
defer res.Body.Close()
bodyRes, _ := io.ReadAll(res.Body)
var imagesStruct ImagesStruct
err = json.Unmarshal(bodyRes, &imagesStruct)
if err != nil {
logrus.Error(err)
}
fmt.Println(bodyRes)
imgStr, err := base64.StdEncoding.DecodeString(imagesStruct.Images[0])
if err != nil {
logrus.Error("base64 decode error:", err)
}
//生成图片
timeNow := time.Now().Unix()
fileName := "./stable-diffusion-webui/txt2img-images/" + strconv.FormatInt(timeNow, 10) + ".png"
err = ImgUpload(fileName, string(imgStr))
if err != nil {
logrus.Error(err)
}
//上传图片
targetUrl := "http://static.rejo9.com/stable-diffusion-webui/txt2img-images/"
aaa, err := PostFile(fileName, targetUrl)
if err != nil {
logrus.Error(err)
}
logrus.Info(aaa)
}
return
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment