edit : dalle, imagen 추가

This commit is contained in:
2025-01-16 15:17:12 +09:00
parent bd28fcdca4
commit 9ed67d3e50
15 changed files with 503 additions and 222 deletions

View File

@@ -0,0 +1,257 @@
import argparse
import asyncio
import contextlib
import json
import os
import random
import sys
import time
from functools import partial
from typing import Dict
from typing import List
from typing import Union
import httpx
import pkg_resources
import regex
import requests
from rest.app.utils.parsing_utils import prompt_to_filenames
from rest.app.utils.date_utils import D
BING_URL = os.getenv("BING_URL", "https://www.bing.com")
# Generate random IP between range 13.104.0.0/14
FORWARDED_IP = (
f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
)
HEADERS = {
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"accept-language": "en-US,en;q=0.9",
"cache-control": "max-age=0",
"content-type": "application/x-www-form-urlencoded",
"referrer": "https://www.bing.com/images/create/",
"origin": "https://www.bing.com",
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36 Edg/110.0.1587.63",
"x-forwarded-for": FORWARDED_IP,
}
# Error messages
error_timeout = "Your request has timed out."
error_redirect = "Redirect failed"
error_blocked_prompt = (
"Your prompt has been blocked by Bing. Try to change any bad words and try again."
)
error_being_reviewed_prompt = "Your prompt is being reviewed by Bing. Try to change any sensitive words and try again."
error_noresults = "Could not get results"
error_unsupported_lang = "\nthis language is currently not supported by bing"
error_bad_images = "Bad images"
error_no_images = "No images"
# Action messages
sending_message = "Sending request..."
wait_message = "Waiting for results..."
download_message = "\nDownloading images..."
def debug(debug_file, text_var):
"""helper function for debug"""
with open(f"{debug_file}", "a", encoding="utf-8") as f:
f.write(str(text_var))
f.write("\n")
class ImageGenAsync:
"""
Image generation by Microsoft Bing
Parameters:
auth_cookie: str
Optional Parameters:
debug_file: str
quiet: bool
all_cookies: list[dict]
"""
def __init__(
self,
auth_cookie: str = None,
debug_file: Union[str, None] = None,
quiet: bool = False,
all_cookies: List[Dict] = None,
) -> None:
if auth_cookie is None and not all_cookies:
raise Exception("No auth cookie provided")
self.session = httpx.AsyncClient(
headers=HEADERS,
trust_env=True,
)
if auth_cookie:
self.session.cookies.update({"_U": auth_cookie})
if all_cookies:
for cookie in all_cookies:
self.session.cookies.update(
{cookie["name"]: cookie["value"]},
)
self.quiet = quiet
self.debug_file = debug_file
if self.debug_file:
self.debug = partial(debug, self.debug_file)
async def __aenter__(self):
return self
async def __aexit__(self, *excinfo) -> None:
await self.session.aclose()
async def get_images(self, prompt: str) -> list:
"""
Fetches image links from Bing
Parameters:
prompt: str
"""
if not self.quiet:
print("Sending request...")
url_encoded_prompt = requests.utils.quote(prompt)
# https://www.bing.com/images/create?q=<PROMPT>&rt=3&FORM=GENCRE
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=3&FORM=GENCRE"
payload = f"q={url_encoded_prompt}&qs=ds"
response = await self.session.post(
url,
follow_redirects=False,
data=payload,
)
content = response.text
if "this prompt has been blocked" in content.lower():
raise Exception(
"Your prompt has been blocked by Bing. Try to change any bad words and try again.",
)
if response.status_code != 302:
# if rt4 fails, try rt3
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE"
response = await self.session.post(
url,
follow_redirects=False,
timeout=200,
)
if response.status_code != 302:
print(f"ERROR: {response.text}")
raise Exception("Redirect failed")
# Get redirect URL
redirect_url = response.headers["Location"].replace("&nfy=1", "")
request_id = redirect_url.split("id=")[-1]
await self.session.get(f"{BING_URL}{redirect_url}")
# https://www.bing.com/images/create/async/results/{ID}?q={PROMPT}
polling_url = f"{BING_URL}/images/create/async/results/{request_id}?q={url_encoded_prompt}"
# Poll for results
if not self.quiet:
print("Waiting for results...")
while True:
if not self.quiet:
print(".", end="", flush=True)
# By default, timeout is 300s, change as needed
response = await self.session.get(polling_url)
if response.status_code != 200:
raise Exception("Could not get results")
content = response.text
if content and content.find("errorMessage") == -1:
break
await asyncio.sleep(1)
continue
# Use regex to search for src=""
image_links = regex.findall(r'src="([^"]+)"', content)
# Remove size limit
normal_image_links = [link.split("?w=")[0] for link in image_links]
# Remove duplicates
normal_image_links = list(set(normal_image_links))
# Bad images
bad_images = [
"https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
"https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
]
for im in normal_image_links:
if im in bad_images:
raise Exception("Bad images")
# No images
if not normal_image_links:
raise Exception("No images")
return normal_image_links
async def save_images(
self,
links: list,
output_dir: str,
download_count: int,
file_name: str = None,
prompt: str = None
) -> None:
"""
Saves images to output directory
"""
if self.debug_file:
self.debug(download_message)
if not self.quiet:
print(download_message)
with contextlib.suppress(FileExistsError):
os.mkdir(output_dir)
try:
model = "dalle3"
jpeg_index = 0
parsing_file_name = prompt_to_filenames(prompt)
for link in links[:download_count]:
if download_count == 1:
_path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{D.date_file_name()}.jpg")
if os.path.exists(_path):
raise Exception("파일 이미 존재함")
while os.path.exists(_path):
jpeg_index += 1
response = await self.session.get(link)
if response.status_code != 200:
raise Exception("Could not download image")
# save response to file
with open(_path,"wb") as output_file:
output_file.write(response.content)
jpeg_index += 1
else:
_path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{jpeg_index}_{D.date_file_name()}.png")
if os.path.exists(_path):
raise Exception("파일 이미 존재함")
while os.path.exists(_path):
jpeg_index += 1
response = await self.session.get(link)
if response.status_code != 200:
raise Exception("Could not download image")
# save response to file
with open(_path,"wb") as output_file:
output_file.write(response.content)
jpeg_index += 1
except httpx.InvalidURL as url_exception:
raise Exception(
"Inappropriate contents found in the generated images. Please try again or try another prompt.",
) from url_exception
async def async_image_gen(
prompt: str,
download_count: int,
output_dir: str,
u_cookie=None,
debug_file=None,
quiet=False,
all_cookies=None,
):
async with ImageGenAsync(
u_cookie,
debug_file=debug_file,
quiet=quiet,
all_cookies=all_cookies,
) as image_generator:
images = await image_generator.get_images(prompt)
await image_generator.save_images(
images, output_dir=output_dir, download_count=download_count,prompt=prompt
)

View File

@@ -0,0 +1,88 @@
import nest_asyncio
import os
from dataclasses import dataclass
from custom_apps.dalle3.BingImageCreator import *
from const import OUTPUT_FOLDER
class CookieManager:
DEFAULT_COOKIE = "15yugVy08XGWEWtpv2SW7-mG_0HsBxbDyBFQDdnKEEK6c-XkjYt4HOk6G_5wY4npT0yKB9yEbl76i7RB5CM_HocUZ-nIOIseVMFVPkg-aTeA82BrEjVSh6ohZaz3rUk9Lw0NpDCV60Mn8s4nyQo8vSZJDlsqVEYXklSyKbEbnrtLPNcXUY5gl9Fmjz5Lxr1CMiBr8ogt3UvmWUBYDA-b4-6SEranbD_2wY_KSVN0djJg"
def __init__(self):
self.cookie = self.DEFAULT_COOKIE
def get_cookie(self):
if not self.cookie:
raise
else:
return self.cookie
def set_cookie(self,new_cookie):
self.cookie = new_cookie
cookie_manager = CookieManager()
@dataclass
class DallEArgument:
"""
U : Auth cookie from browser
cookie_file : File containing auth cookie
prompt: Prompt to generate images for
output_dir: Output directory
download_count: Number of images to download, value must be less than five
debug_file: Path to the file where debug information will be written.
quiet: Disable pipeline messages
asyncio: Run ImageGen using asyncio
version: Print the version number
"""
U = cookie_manager.get_cookie()
prompt: str
cookie_file: str|None = None
output_dir: str = os.path.join(OUTPUT_FOLDER,"dalle")
download_count: int = 1
debug_file: str|None = None
quiet: bool = False
asyncio: bool = True
version: bool = False
def dalle3_generate_image(args):
nest_asyncio.apply()
if not os.path.isdir(args.output_dir):
raise FileExistsError(f"FileExistsError: {args.output_dir}")
if args.version:
print(pkg_resources.get_distribution("BingImageCreator").version)
sys.exit()
# Load auth cookie
cookie_json = None
if args.cookie_file is not None:
with contextlib.suppress(Exception):
with open(args.cookie_file, encoding="utf-8") as file:
cookie_json = json.load(file)
if args.U is None and args.cookie_file is None:
raise Exception("Could not find auth cookie")
if args.download_count > 4:
raise Exception("The number of downloads must be less than five")
asyncio.run(
async_image_gen(
args.prompt,
args.download_count,
args.output_dir,
args.U,
args.debug_file,
args.quiet,
all_cookies=cookie_json,
),
)

View File

@@ -0,0 +1 @@
key = "AIzaSyB7tu67y9gOkJkpQtvI5OAYSzUzwv9qwnE"

View File

@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
import vertexai
import os
from vertexai.preview.vision_models import ImageGenerationModel
from const import OUTPUT_FOLDER
from rest.app.utils.parsing_utils import prompt_to_filenames
from rest.app.utils.date_utils import D
class ImagenConst:
project_id = "imagen-447301"
location = "asia-east1"
model = "imagen-3.0-generate-001"
def imagen_generate_image(prompt,download_count=1):
vertexai.init(project=ImagenConst.project_id, location=ImagenConst.location)
model = ImageGenerationModel.from_pretrained(ImagenConst.model)
_file_name = prompt_to_filenames(prompt)
_folder = os.path.join(OUTPUT_FOLDER,"imagen")
if not os.path.isdir(_folder):
raise FileExistsError(f"FileExistsError: {_folder}")
images = model.generate_images(
prompt=prompt,
# Optional parameters
number_of_images=download_count,
language="ko",
# You can't use a seed value and watermark at the same time.
# add_watermark=False,
# seed=100,
aspect_ratio="1:1",
safety_filter_level="block_some",
person_generation="dont_allow",
)
if len(images.images) <=1 :
images[0].save(location=os.path.join(_folder,f"imagen_{_file_name}_{D.date_file_name()}.png"), include_generation_parameters=False)
else:
for i in range(len(images.images)):
images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i}_{D.date_file_name()}.png"), include_generation_parameters=False)