์ง ์ปดํจํฐ๊ฐ ์ข์ง ๋ชปํด stable diffusion์ colab์์ ์ฌ์ฉํด๋ณผ ์ ์๋ ์ฝ๋๊ฐ ์์ด์
๋ฐ๋ผ์ ํด๋ณด๊ณ ์ฌ์ฉํด๋ดค๋๋ฐ ์ฑ๋ฅ์ด web ui๋ณด๋ค ์ข์ง ๋ชปํด ์ค๋งํ๋ค.
์ฝ๋ฉ์ ๋๊ณ ํ์ฌ์์ ์ฌ์ฉํ api๋ฅผ ์ฐพ์๋ณด๊ณ ์์๋๋ฐ ์นด์นด์ค์์๋ text to image๋ฅผ ์๋น์คํ๊ณ ์์๋ค.
๊ทธ๋์ ํ๋ฒ ์ฌ์ฉํด๋ณด๊ณ , ํด๋ณด๊ณ ์ถ์ ์ฌ๋๋ค์ด ์์ ์ ์์ ๊ฒ ๊ฐ์ ์ฌ์ฉ๋ฐฉ๋ฒ์ ์๋ ค์ฃผ๋ ค๊ณ ํ๋ค.
๋ฏธ๋ฆฌ ๋งํ์ง๋ง ์ฑ๋ฅ์ web ui๋ณด๋ค ์ข์ง ๋ชปํ๋ค.
์๋ฒ์์ ์ ํด์ง ๋ชจ๋ธ์ ํ๋กฌํฌํธ๋ง ์ ๋ ฅํ์ฌ ์ด๋ฏธ์ง๋ฅผ ๊ฐ์ง๊ณ ์ค๋๊ฑฐ๋ผ
์ข์ ์ฑ๋ฅ์ ๊ธฐ๋ํ๋ ค๋ฉด webui๋ฅผ ๊ฐ์ ๊ณต๋ถํ๊ณ ์ฌ์ฉํ๋๊ฒ ์ข๋ค.
์ทจ์ง๋ ์นด์นด์คํก์์ text to image api๋ฅผ ๋ง๋ค์์ผ๋ ํ๋ฒ ์ฐ์ด๋จน์ด๋ณด์๋ผ๋ ์๋ฏธ์ด๋ค.
๋จผ์ ์นด์นด์ค๋๋ฒจ๋กํผ์ ๋ค์ด๊ฐ์ ํ์๊ฐ์ ์ ํ๋ค.
ํ์๊ฐ์ ์ ํ๊ณ '๋ด ์ ํ๋ฆฌ์ผ์ด์ '์ ๋๋ฅด๊ณ '์ ํ๋ฆฌ์ผ์ด์ ์ถ๊ฐํ๊ธฐ'๋ฅผ ๋๋ฅธ๋ค.
์ฑ ์ด๋ฆ๊ณผ ์ฌ์ ์๋ช ์ ์ ์ด์ค๋ค. ๊ทธ๋ฅ ๊ณต๋ถํ ๋ชฉ์ ์ด๋ฉด ์ ํ๋ฆฌ์ผ์ด์ ์ด ์ด๋ค ์ฉ๋๋ก ์ฐ์ผ์ง ์ ์ด์ฃผ๋ฉด ๋๋ค.
๊ทธ๋ฆฌ๊ณ ์ฒดํฌ๋ฐ์ค์ ์ฒดํฌํ๊ณ ์ ์ฅํ๋ฉด ๋๋ค.
์ ์ฅํ์ฌ ๋ง๋ค์ด์ง ์ ํ๋ฆฌ์ผ์ด์ ์ ํด๋ฆญํ๋ฉด ์ฑ ํค๋ฅผ ํ์ธ ํ ์ ์๋ค.
์ฐ๋ฆฌ๋ api๋ฅผ ์ฌ์ฉํ ๊ฒ์ด๋ผ REST API ํค๋ฅผ ๋ณต์ฌํด์ ์ฌ์ฉํ ๊ฒ์ด๋ค.
๊ทธ๋ค์ ์ฝ๋ฉ์ ํค๊ฑฐ๋ vscode๋ฅผ ์ผ์ ์ฝ๋๋ฅผ ์ ๋ ฅํ๋ฉด ๋๋ค.
๋๋ ์ฝ๋ฉ์์ ์ฌ์ฉํด์ ์ฝ๋ฉ๊ธฐ์ค์ผ๋ก ์๋ ค์ฃผ๋ ค ํ๋ค.
# REST API ํธ์ถ, ์ด๋ฏธ์ง ํ์ผ ์ฒ๋ฆฌ์ ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ
import requests
import json
import io
import base64
from PIL import Image
# [๋ด ์ ํ๋ฆฌ์ผ์ด์
] > [์ฑ ํค] ์์ ํ์ธํ REST API ํค ๊ฐ ์
๋ ฅ
REST_API_KEY = '๋ณธ์ธ์ rest apiํค๋ฅผ ์
๋ ฅ'
์ฌ๊ธฐ์ ์นด์นด์ค ์์ ์์๋ '${rest_api}' ์์ง๋ง ${}๋ฅผ ์์ง์ฐ๋ฉด
์๋ฌ์ฒ๋ฆฌ๊ฐ ๋์ ์ง์์ ์ ๋ ฅํ์๋๋ ์๋ต์ด ์ ์๋ค.
# ์ด๋ฏธ์ง ์์ฑํ๊ธฐ ์์ฒญ
def t2i(text, batch_size=1):
r = requests.post( # ์นด์นด์คํํ
์์ฒญํ๊ฒ ๋ค.
'https://api.kakaobrain.com/v1/inference/karlo/t2i', # ์ด์ฌ์ดํธ๋ก
json = {
'prompt': {
'text': text, # ๋ด๊ฐ ์ํ๋ ๊ทธ๋ฆผ์ ํ๋กฌํฌํธ๋ text์ด๊ณ ,
'batch_size': batch_size # batch_size๋งํผ ๊ทธ๋ ค์ค
}
},
headers = {
'Authorization': f'KakaoAK {REST_API_KEY}', # ์์ฒญํ ๋ ํ์ํ ์ํธ
'Content-Type': 'application/json'
}
)
# ์๋ต JSON ํ์์ผ๋ก ๋ณํ
response = json.loads(r.content)
return response
# Base64 ๋์ฝ๋ฉ ๋ฐ ๋ณํ (์ด๋ฏธ์ง๋ก ๋ฐ๋ก ๋ณด๋ด๋ฉด ๋ญ๋น)
def stringToImage(base64_string, mode='RGBA'):
imgdata = base64.b64decode(str(base64_string))
img = Image.open(io.BytesIO(imgdata)).convert(mode)
return img
์ฝ๋๋ฅผ ๊ฐ๋จํ ์ค๋ช ํ๋ ค๊ณ ์ฃผ์์ฒ๋ฆฌ๋ฅผ ๋ฌ์๋ค. text์ ๋ค์ด๊ฐ๋๊ฑด ์ด๋ค ๊ทธ๋ฆผ์ด ๋์์ผ๋ฉด ์ข์์ง ์ ๋ ๊ฒ์ด๋ค.
batch_size๋ ์ฒ์์ ํํฐ๋ฆฌ๊ฐ ์ข์์ง๋์ง ์์๋๋ฐ ๊ทธ๋ฆผ์ ์ฅ ์์ด๋ค.
1~8์ฅ๊น์ง ์ธ ์ ์๋ค.
# ํ๋กฌํํธ์ ์ฌ์ฉํ ์ ์์ด
text = "a baby in a cute penguin suit facing the front"
# ์ด๋ฏธ์ง ์์ฑํ๊ธฐ REST API ํธ์ถ
response = t2i(text, 8)
text๋ ์์ด๋ฐ์ ์๋๊ณ 256์๊น์ง ์ธ ์ ์๋ค. ๋์ฒ๋ผ ๋ง์ด ๋๊ฒ ์จ๋ ๊ด์ฐฎ๊ณ ,
' , ' ๋ฅผ ๊ธฐ์ค์ผ๋ก ์๋ฌด ๋จ์ด๋ ๋์ดํด๋ ์์์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ค.
# ์๋ต์ ์ฒซ ๋ฒ์งธ ์ด๋ฏธ์ง ์์ฑ ๊ฒฐ๊ณผ ์ถ๋ ฅํ๊ธฐ
result = stringToImage(response.get("images")[0].get("image"), mode='RGB')
result
response.get("images")[0] ์ด ๋ถ๋ถ์ 0~7๊น์ง ๊ณ ์น๋ฉด ๋๋ค. batch_size-1 ๊น์ง ๋๋ฆด ์ ์๋ค.
๋ง์ฝ ์ฌ๊ธฐ์ ์ด๋ฏธ์ง๊ฐ ์ถ๋ ฅ์ด ๋์ง ์๋๋ค๋ฉด response๋ฅผ ์ฐ์ด๋ณด๋ฉด ์๋ฌ ์๋ต์ ์ฃผ์์ ํ๋ฅ ์ด ๋๋ค.
๋๋ ์ฒ์์ batch_size๋ฅผ 10์ ์ฃผ์๋๋ฐ ์๊พธ bad requests๊ฐ ์์ ์ฐพ์๋ณด๋ 8๊น์ง๋ฐ์ ์๋๋ค๊ณ ์จ์ ธ์์๋ค.
๋ค๋ฅธ ๊ธฐ๋ฅ๋ 2๊ฐ์ง ์ ๋ ๋ ์์ผ๋ ๊ถ๊ธํ๋ฉด ์ง์ ํด๋ณด๋ฉด ์ฌ๋ฐ์ ๊ฒ์ด๋ค.
๋๊ธ