edit : dalle, imagen 추가
This commit is contained in:
257
custom_apps/dalle3/BingImageCreator.py
Normal file
257
custom_apps/dalle3/BingImageCreator.py
Normal file
@@ -0,0 +1,257 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
import httpx
|
||||
import pkg_resources
|
||||
import regex
|
||||
import requests
|
||||
|
||||
from rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from rest.app.utils.date_utils import D
|
||||
|
||||
BING_URL = os.getenv("BING_URL", "https://www.bing.com")
|
||||
# Generate random IP between range 13.104.0.0/14
|
||||
FORWARDED_IP = (
|
||||
f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
|
||||
)
|
||||
HEADERS = {
|
||||
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
|
||||
"accept-language": "en-US,en;q=0.9",
|
||||
"cache-control": "max-age=0",
|
||||
"content-type": "application/x-www-form-urlencoded",
|
||||
"referrer": "https://www.bing.com/images/create/",
|
||||
"origin": "https://www.bing.com",
|
||||
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36 Edg/110.0.1587.63",
|
||||
"x-forwarded-for": FORWARDED_IP,
|
||||
}
|
||||
|
||||
# Error messages
|
||||
error_timeout = "Your request has timed out."
|
||||
error_redirect = "Redirect failed"
|
||||
error_blocked_prompt = (
|
||||
"Your prompt has been blocked by Bing. Try to change any bad words and try again."
|
||||
)
|
||||
error_being_reviewed_prompt = "Your prompt is being reviewed by Bing. Try to change any sensitive words and try again."
|
||||
error_noresults = "Could not get results"
|
||||
error_unsupported_lang = "\nthis language is currently not supported by bing"
|
||||
error_bad_images = "Bad images"
|
||||
error_no_images = "No images"
|
||||
# Action messages
|
||||
sending_message = "Sending request..."
|
||||
wait_message = "Waiting for results..."
|
||||
download_message = "\nDownloading images..."
|
||||
|
||||
|
||||
def debug(debug_file, text_var):
|
||||
"""helper function for debug"""
|
||||
with open(f"{debug_file}", "a", encoding="utf-8") as f:
|
||||
f.write(str(text_var))
|
||||
f.write("\n")
|
||||
|
||||
|
||||
class ImageGenAsync:
|
||||
"""
|
||||
Image generation by Microsoft Bing
|
||||
Parameters:
|
||||
auth_cookie: str
|
||||
Optional Parameters:
|
||||
debug_file: str
|
||||
quiet: bool
|
||||
all_cookies: list[dict]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
auth_cookie: str = None,
|
||||
debug_file: Union[str, None] = None,
|
||||
quiet: bool = False,
|
||||
all_cookies: List[Dict] = None,
|
||||
) -> None:
|
||||
if auth_cookie is None and not all_cookies:
|
||||
raise Exception("No auth cookie provided")
|
||||
self.session = httpx.AsyncClient(
|
||||
headers=HEADERS,
|
||||
trust_env=True,
|
||||
)
|
||||
if auth_cookie:
|
||||
self.session.cookies.update({"_U": auth_cookie})
|
||||
if all_cookies:
|
||||
for cookie in all_cookies:
|
||||
self.session.cookies.update(
|
||||
{cookie["name"]: cookie["value"]},
|
||||
)
|
||||
self.quiet = quiet
|
||||
self.debug_file = debug_file
|
||||
if self.debug_file:
|
||||
self.debug = partial(debug, self.debug_file)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *excinfo) -> None:
|
||||
await self.session.aclose()
|
||||
|
||||
async def get_images(self, prompt: str) -> list:
|
||||
"""
|
||||
Fetches image links from Bing
|
||||
Parameters:
|
||||
prompt: str
|
||||
"""
|
||||
if not self.quiet:
|
||||
print("Sending request...")
|
||||
url_encoded_prompt = requests.utils.quote(prompt)
|
||||
# https://www.bing.com/images/create?q=<PROMPT>&rt=3&FORM=GENCRE
|
||||
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=3&FORM=GENCRE"
|
||||
payload = f"q={url_encoded_prompt}&qs=ds"
|
||||
response = await self.session.post(
|
||||
url,
|
||||
follow_redirects=False,
|
||||
data=payload,
|
||||
)
|
||||
content = response.text
|
||||
if "this prompt has been blocked" in content.lower():
|
||||
raise Exception(
|
||||
"Your prompt has been blocked by Bing. Try to change any bad words and try again.",
|
||||
)
|
||||
if response.status_code != 302:
|
||||
# if rt4 fails, try rt3
|
||||
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE"
|
||||
response = await self.session.post(
|
||||
url,
|
||||
follow_redirects=False,
|
||||
timeout=200,
|
||||
)
|
||||
if response.status_code != 302:
|
||||
print(f"ERROR: {response.text}")
|
||||
raise Exception("Redirect failed")
|
||||
# Get redirect URL
|
||||
redirect_url = response.headers["Location"].replace("&nfy=1", "")
|
||||
request_id = redirect_url.split("id=")[-1]
|
||||
await self.session.get(f"{BING_URL}{redirect_url}")
|
||||
# https://www.bing.com/images/create/async/results/{ID}?q={PROMPT}
|
||||
polling_url = f"{BING_URL}/images/create/async/results/{request_id}?q={url_encoded_prompt}"
|
||||
# Poll for results
|
||||
if not self.quiet:
|
||||
print("Waiting for results...")
|
||||
while True:
|
||||
if not self.quiet:
|
||||
print(".", end="", flush=True)
|
||||
# By default, timeout is 300s, change as needed
|
||||
response = await self.session.get(polling_url)
|
||||
if response.status_code != 200:
|
||||
raise Exception("Could not get results")
|
||||
content = response.text
|
||||
if content and content.find("errorMessage") == -1:
|
||||
break
|
||||
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
# Use regex to search for src=""
|
||||
image_links = regex.findall(r'src="([^"]+)"', content)
|
||||
# Remove size limit
|
||||
normal_image_links = [link.split("?w=")[0] for link in image_links]
|
||||
# Remove duplicates
|
||||
normal_image_links = list(set(normal_image_links))
|
||||
|
||||
# Bad images
|
||||
bad_images = [
|
||||
"https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
|
||||
"https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
|
||||
]
|
||||
for im in normal_image_links:
|
||||
if im in bad_images:
|
||||
raise Exception("Bad images")
|
||||
# No images
|
||||
if not normal_image_links:
|
||||
raise Exception("No images")
|
||||
return normal_image_links
|
||||
|
||||
async def save_images(
|
||||
self,
|
||||
links: list,
|
||||
output_dir: str,
|
||||
download_count: int,
|
||||
file_name: str = None,
|
||||
prompt: str = None
|
||||
) -> None:
|
||||
"""
|
||||
Saves images to output directory
|
||||
"""
|
||||
|
||||
if self.debug_file:
|
||||
self.debug(download_message)
|
||||
if not self.quiet:
|
||||
print(download_message)
|
||||
with contextlib.suppress(FileExistsError):
|
||||
os.mkdir(output_dir)
|
||||
try:
|
||||
model = "dalle3"
|
||||
jpeg_index = 0
|
||||
parsing_file_name = prompt_to_filenames(prompt)
|
||||
|
||||
for link in links[:download_count]:
|
||||
if download_count == 1:
|
||||
_path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{D.date_file_name()}.jpg")
|
||||
if os.path.exists(_path):
|
||||
raise Exception("파일 이미 존재함")
|
||||
|
||||
while os.path.exists(_path):
|
||||
jpeg_index += 1
|
||||
response = await self.session.get(link)
|
||||
if response.status_code != 200:
|
||||
raise Exception("Could not download image")
|
||||
# save response to file
|
||||
with open(_path,"wb") as output_file:
|
||||
output_file.write(response.content)
|
||||
jpeg_index += 1
|
||||
|
||||
else:
|
||||
_path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{jpeg_index}_{D.date_file_name()}.png")
|
||||
if os.path.exists(_path):
|
||||
raise Exception("파일 이미 존재함")
|
||||
|
||||
while os.path.exists(_path):
|
||||
jpeg_index += 1
|
||||
response = await self.session.get(link)
|
||||
if response.status_code != 200:
|
||||
raise Exception("Could not download image")
|
||||
# save response to file
|
||||
with open(_path,"wb") as output_file:
|
||||
output_file.write(response.content)
|
||||
jpeg_index += 1
|
||||
|
||||
except httpx.InvalidURL as url_exception:
|
||||
raise Exception(
|
||||
"Inappropriate contents found in the generated images. Please try again or try another prompt.",
|
||||
) from url_exception
|
||||
|
||||
|
||||
async def async_image_gen(
|
||||
prompt: str,
|
||||
download_count: int,
|
||||
output_dir: str,
|
||||
u_cookie=None,
|
||||
debug_file=None,
|
||||
quiet=False,
|
||||
all_cookies=None,
|
||||
):
|
||||
async with ImageGenAsync(
|
||||
u_cookie,
|
||||
debug_file=debug_file,
|
||||
quiet=quiet,
|
||||
all_cookies=all_cookies,
|
||||
) as image_generator:
|
||||
images = await image_generator.get_images(prompt)
|
||||
await image_generator.save_images(
|
||||
images, output_dir=output_dir, download_count=download_count,prompt=prompt
|
||||
)
|
||||
88
custom_apps/dalle3/custom_dalle.py
Normal file
88
custom_apps/dalle3/custom_dalle.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import nest_asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
from custom_apps.dalle3.BingImageCreator import *
|
||||
from const import OUTPUT_FOLDER
|
||||
|
||||
|
||||
class CookieManager:
|
||||
|
||||
DEFAULT_COOKIE = "15yugVy08XGWEWtpv2SW7-mG_0HsBxbDyBFQDdnKEEK6c-XkjYt4HOk6G_5wY4npT0yKB9yEbl76i7RB5CM_HocUZ-nIOIseVMFVPkg-aTeA82BrEjVSh6ohZaz3rUk9Lw0NpDCV60Mn8s4nyQo8vSZJDlsqVEYXklSyKbEbnrtLPNcXUY5gl9Fmjz5Lxr1CMiBr8ogt3UvmWUBYDA-b4-6SEranbD_2wY_KSVN0djJg"
|
||||
|
||||
def __init__(self):
|
||||
self.cookie = self.DEFAULT_COOKIE
|
||||
|
||||
def get_cookie(self):
|
||||
if not self.cookie:
|
||||
raise
|
||||
else:
|
||||
return self.cookie
|
||||
|
||||
def set_cookie(self,new_cookie):
|
||||
self.cookie = new_cookie
|
||||
|
||||
cookie_manager = CookieManager()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DallEArgument:
|
||||
"""
|
||||
U : Auth cookie from browser
|
||||
cookie_file : File containing auth cookie
|
||||
prompt: Prompt to generate images for
|
||||
output_dir: Output directory
|
||||
download_count: Number of images to download, value must be less than five
|
||||
debug_file: Path to the file where debug information will be written.
|
||||
quiet: Disable pipeline messages
|
||||
asyncio: Run ImageGen using asyncio
|
||||
version: Print the version number
|
||||
"""
|
||||
U = cookie_manager.get_cookie()
|
||||
prompt: str
|
||||
cookie_file: str|None = None
|
||||
output_dir: str = os.path.join(OUTPUT_FOLDER,"dalle")
|
||||
download_count: int = 1
|
||||
debug_file: str|None = None
|
||||
quiet: bool = False
|
||||
asyncio: bool = True
|
||||
version: bool = False
|
||||
|
||||
|
||||
def dalle3_generate_image(args):
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
if not os.path.isdir(args.output_dir):
|
||||
raise FileExistsError(f"FileExistsError: {args.output_dir}")
|
||||
|
||||
if args.version:
|
||||
print(pkg_resources.get_distribution("BingImageCreator").version)
|
||||
sys.exit()
|
||||
|
||||
# Load auth cookie
|
||||
cookie_json = None
|
||||
if args.cookie_file is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
with open(args.cookie_file, encoding="utf-8") as file:
|
||||
cookie_json = json.load(file)
|
||||
|
||||
if args.U is None and args.cookie_file is None:
|
||||
raise Exception("Could not find auth cookie")
|
||||
|
||||
if args.download_count > 4:
|
||||
raise Exception("The number of downloads must be less than five")
|
||||
|
||||
|
||||
asyncio.run(
|
||||
async_image_gen(
|
||||
args.prompt,
|
||||
args.download_count,
|
||||
args.output_dir,
|
||||
args.U,
|
||||
args.debug_file,
|
||||
args.quiet,
|
||||
all_cookies=cookie_json,
|
||||
),
|
||||
)
|
||||
|
||||
1
custom_apps/imagen/const.py
Normal file
1
custom_apps/imagen/const.py
Normal file
@@ -0,0 +1 @@
|
||||
key = "AIzaSyB7tu67y9gOkJkpQtvI5OAYSzUzwv9qwnE"
|
||||
46
custom_apps/imagen/custom_imagen.py
Normal file
46
custom_apps/imagen/custom_imagen.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import vertexai
|
||||
import os
|
||||
from vertexai.preview.vision_models import ImageGenerationModel
|
||||
|
||||
from const import OUTPUT_FOLDER
|
||||
from rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from rest.app.utils.date_utils import D
|
||||
|
||||
|
||||
class ImagenConst:
|
||||
project_id = "imagen-447301"
|
||||
location = "asia-east1"
|
||||
model = "imagen-3.0-generate-001"
|
||||
|
||||
|
||||
def imagen_generate_image(prompt,download_count=1):
|
||||
vertexai.init(project=ImagenConst.project_id, location=ImagenConst.location)
|
||||
|
||||
model = ImageGenerationModel.from_pretrained(ImagenConst.model)
|
||||
|
||||
_file_name = prompt_to_filenames(prompt)
|
||||
_folder = os.path.join(OUTPUT_FOLDER,"imagen")
|
||||
|
||||
if not os.path.isdir(_folder):
|
||||
raise FileExistsError(f"FileExistsError: {_folder}")
|
||||
|
||||
images = model.generate_images(
|
||||
prompt=prompt,
|
||||
# Optional parameters
|
||||
number_of_images=download_count,
|
||||
language="ko",
|
||||
# You can't use a seed value and watermark at the same time.
|
||||
# add_watermark=False,
|
||||
# seed=100,
|
||||
aspect_ratio="1:1",
|
||||
safety_filter_level="block_some",
|
||||
person_generation="dont_allow",
|
||||
)
|
||||
|
||||
if len(images.images) <=1 :
|
||||
images[0].save(location=os.path.join(_folder,f"imagen_{_file_name}_{D.date_file_name()}.png"), include_generation_parameters=False)
|
||||
else:
|
||||
for i in range(len(images.images)):
|
||||
images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i}_{D.date_file_name()}.png"), include_generation_parameters=False)
|
||||
Reference in New Issue
Block a user