Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import re | |
| from claude import embed_base64_for_claude, create_claude_image_request_for_image_captioning, \ | |
| create_claude_request_for_text_completion, extract_data_from_text_xml | |
| from prompts import prompts | |
| from constants import JSON_SCHEMA_FOR_GPT, UPDATED_MODEL_ONLY_SCHEMA, JSON_SCHEMA_FOR_LOC_ONLY | |
| from gpt import runAssistant, checkRunStatus, retrieveThread, createAssistant, saveFileOpenAI, startAssistantThread, \ | |
| create_chat_completion_request_open_ai_for_summary, addMessageToThread, create_image_completion_request_gpt | |
| from summarizer import create_brand_html, create_langchain_openai_query, create_screenshot_from_scrap_fly, check_and_compress_image | |
| from theme import flux_generated_image, flux_generated_image_seed | |
| import time | |
| from PIL import Image | |
| import io | |
| from streamlit_gsheets import GSheetsConnection | |
| # conn = st.connection("gsheets", type=GSheetsConnection) | |
def process_run(st, thread_id, assistant_id, poll_interval=20):
    """Run an assistant on a thread, wait for the run to finish, and return its reply.

    Args:
        st: The Streamlit module (anything exposing a ``spinner`` context
            manager works); used to show progress while waiting.
        thread_id: Thread to run the assistant on.
        assistant_id: Assistant to run.
        poll_interval: Seconds to sleep between status checks. Defaults to 20,
            the value that used to be hard-coded.

    Returns:
        The ``content`` of the first non-user message in the order
        ``retrieveThread`` yields them, or ``None`` if there is none.

    Raises:
        RuntimeError: if the run ends in a terminal non-success status. Without
            this check the loop below would poll forever.
    """
    # NOTE(review): assumes checkRunStatus returns the run's `status` string as
    # defined by the OpenAI Assistants API — confirm against gpt.checkRunStatus.
    failed_statuses = {"cancelled", "failed", "expired", "incomplete"}
    run_id = runAssistant(thread_id, assistant_id)
    status = 'running'
    # One spinner for the whole wait instead of re-entering it on every poll.
    with st.spinner('. . .'):
        while status != 'completed':
            time.sleep(poll_interval)
            status = checkRunStatus(thread_id, run_id)
            if status in failed_statuses:
                raise RuntimeError(
                    f"Assistant run {run_id} on thread {thread_id} ended with status {status!r}"
                )
    for message in retrieveThread(thread_id):
        if message['role'] != 'user':
            return message["content"]
    return None
def page5():
    """Render the preferences page where the user reviews and edits the prompts.

    "Save the Prompt" copies the four text areas into ``st.session_state``
    under ``system_prompt``, ``caption_system_prompt``, ``caption_prompt`` and
    ``brand_prompt``. "Start Testing!" switches to "Page 1".

    Before switching, any of those keys that were never saved are filled from
    the current text areas. Later pages read them with ``st.session_state[...]``
    and would otherwise raise ``KeyError`` when the user skips saving.
    Prompts that were explicitly saved are left as they are.
    """
    st.title('Initialize your preferences!')
    system_prompt_passed = st.text_area("System Prompt", value=prompts["PROMPT_FOR_MOOD_AND_IDEA"],
                                        key="System Prompt")
    caption_system_prompt = st.text_area("Captioning System Prompt", value=prompts["CAPTION_SYSTEM_PROMPT"],
                                         key="Caption Generation System Prompt")
    caption_prompt = st.text_area("Caption Prompt", value=prompts["CAPTION_PROMPT"],
                                  key="Caption Generation Prompt")
    brand_summary_prompt = st.text_area("Prompt for Brand Summary", value=prompts["BRAND_SUMMARY_PROMPT"],
                                        key="Brand summary prompt")
    st.text("Running on Claude")
    current_prompts = {
        "system_prompt": system_prompt_passed,
        "caption_system_prompt": caption_system_prompt,
        "caption_prompt": caption_prompt,
        "brand_prompt": brand_summary_prompt,
    }
    col1, col2 = st.columns([1, 2])
    with col1:
        if st.button("Save the Prompt"):
            for key, value in current_prompts.items():
                st.session_state[key] = value
            print(st.session_state["system_prompt"])
            st.success("Saved your prompts")
    with col2:
        if st.button("Start Testing!"):
            # Fill only the keys the user never saved.
            for key, value in current_prompts.items():
                if key not in st.session_state:
                    st.session_state[key] = value
            st.session_state['page'] = "Page 1"
def page1():
    """Render the product-upload page.

    The user uploads product images and writes a description. "Save" stores
    both in ``st.session_state``. "Add product and move to next page" first
    requires at least one image and a non-empty description, then stores both
    and switches to "Page 2".
    """
    st.title("Upload Product")
    st.markdown("<h2 style='color:#FF5733; font-weight:bold;'>Add a Product</h2>", unsafe_allow_html=True)
    st.markdown(
        "<p style='color:#444;'>Upload your product images, more images you upload better the AI learns</p>",
        unsafe_allow_html=True,
    )
    images = st.file_uploader("Upload Images", accept_multiple_files=True, key="uploaded_files_key")
    description = st.text_area("Describe the product", value=st.session_state.get("product_description", ""))

    def remember_product():
        # Both buttons persist the same pair of values.
        st.session_state['uploaded_files'] = images
        st.session_state['product_description'] = description

    save_column, next_column = st.columns([1, 2])
    with save_column:
        if st.button("Save"):
            remember_product()
            st.success("Product information saved!")
    with next_column:
        if not st.button("Add product and move to next page"):
            return
        if not images:
            st.warning("Please upload at least one image.")
        elif not description:
            st.warning("Please provide a description for the product.")
        else:
            remember_product()
            st.session_state['page'] = "Page 2"
def page2():
    """Render the shoot-preference page and optionally fetch a brand summary.

    Stores the shoot type, brand link, editable brand summary, product info
    and reference images in ``st.session_state``. "Give Me Ideas" switches to
    "Page 3".

    "Get Brand Summary" first takes a Scrapfly screenshot of the brand link
    and asks GPT to summarize it. If Scrapfly reports failure, the page HTML
    is summarized through LangChain instead.
    """
    st.title("Tell us about your shoot preference")
    st.markdown("<h3 style='color:#444;'>What are you shooting today?</h3>", unsafe_allow_html=True)
    shoot_type = st.radio("Select your shoot type:", ["Editorial", "Catalogue"], index=0)
    st.session_state['shoot_type'] = shoot_type
    brand_link = st.text_input("Add your brand link:", value=st.session_state.get("brand_link", ""))
    st.session_state['brand_link'] = brand_link
    if st.button("Get Brand Summary"):
        if not brand_link:
            st.warning("Please add a brand link.")
        else:
            st.text("Using Scrapfly")
            brand_summary_html = create_screenshot_from_scrap_fly(brand_link)
            if brand_summary_html["success"]:
                st.image(brand_summary_html["location"])
                # Fall back to the default prompt when the user never clicked
                # "Save the Prompt" on the preferences page (the key would be missing).
                brand_prompt = st.session_state.get("brand_prompt", prompts["BRAND_SUMMARY_PROMPT"])
                st.session_state['brand_summary'] = create_image_completion_request_gpt(
                    brand_summary_html["location"], brand_prompt)
            else:
                st.text(f"Scrapfly failed due to: {brand_summary_html}")
                st.text("Using Langchain")
                brand_html = create_brand_html(brand_link)
                st.session_state['brand_summary'] = create_langchain_openai_query(brand_html)
            st.success("Brand summary fetched!")
    brand_summary_value = st.session_state.get('brand_summary', "")
    editable_summary = st.text_area("Brand Summary:", value=brand_summary_value, height=100)
    st.session_state['brand_summary'] = editable_summary
    product_info = st.text_area("Tell us something about your product:", value=st.session_state.get("product_info", ""))
    st.session_state['product_info'] = product_info
    reference_images = st.file_uploader("Upload Reference Images", accept_multiple_files=True,
                                        key="reference_images_key")
    st.session_state['reference_images'] = reference_images
    if st.button("Give Me Ideas"):
        st.session_state['page'] = "Page 3"
def _text_message(role, text):
    """Build one message dict holding a single text block.

    This is the shape passed to ``create_claude_request_for_text_completion``.
    """
    return {"role": role, "content": [{"type": "text", "text": text}]}


def _caption_uploaded_images(uploaded_files, file_prefix):
    """Caption every uploaded image with Claude and return the captions in upload order.

    Steps:
    1. Each upload is checked with PIL's ``verify()``.
    2. It is written to a temporary ``.png`` file, because
       ``embed_base64_for_claude`` takes a path.
    3. It is embedded and captioned with the caption prompts stored in
       ``st.session_state``.

    The temporary files are deleted afterwards, even if a request fails.

    Args:
        uploaded_files: Objects whose ``getvalue()`` returns image bytes, such as
            Streamlit ``UploadedFile``. ``None`` is treated as no files.
        file_prefix: Prefix for the temporary file names.

    Returns:
        The list of values returned by
        ``create_claude_image_request_for_image_captioning``, one per upload.
    """
    import tempfile

    temp_paths = []
    try:
        for uploaded_file in uploaded_files or []:
            bytes_data = uploaded_file.getvalue()
            # verify() raises on unreadable image data; `with` closes the handle.
            with Image.open(io.BytesIO(bytes_data)) as image:
                image.verify()
            # A unique temp file replaces the old random.randint names, which
            # could collide and were never cleaned up.
            with tempfile.NamedTemporaryFile(prefix=file_prefix, suffix=".png", delete=False) as tmp:
                temp_paths.append(tmp.name)
                tmp.write(bytes_data)
        embeds = [embed_base64_for_claude(path) for path in temp_paths]
        return [
            create_claude_image_request_for_image_captioning(
                st.session_state["caption_system_prompt"], st.session_state["caption_prompt"], embed)
            for embed in embeds
        ]
    finally:
        for path in temp_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass


def _generate_and_store_campaigns(response_from_claude):
    """Turn Claude's campaign ideas into Flux images and store them in session state.

    Each ``<campaign_idea>`` must contain both a ``model_prompt`` and a
    ``background_prompt``. Ideas missing either one are skipped. The two
    prompts are joined with a space and rendered with ``flux_generated_image``.

    Sets ``descriptions`` (the prompts), ``images`` (the generated file names,
    aligned by index with ``descriptions``) and ``claude_context`` (the raw
    response) in ``st.session_state``.
    """
    campaigns = re.findall(r"<campaign_idea>(.*?)</campaign_idea>", response_from_claude, re.DOTALL)
    descriptions = []
    for campaign in campaigns:
        model_prompt = extract_data_from_text_xml(campaign, "model_prompt")
        background_prompt = extract_data_from_text_xml(campaign, "background_prompt")
        if model_prompt and background_prompt:
            descriptions.append(f"{model_prompt.strip()} {background_prompt.strip()}".strip())
    images = [flux_generated_image(prompt)["file_name"] for prompt in descriptions]
    st.session_state["descriptions"] = descriptions
    st.session_state["claude_context"] = response_from_claude
    st.session_state["images"] = images


def page3():
    """Render the scene-suggestion page.

    On the first visit (while ``assistant_initialized`` is not set):
    1. Caption the product and reference images.
    2. Fill the system-prompt template with the brand, product and reference
       details.
    3. Ask Claude for campaign ideas and render one image per idea.

    Chat feedback regenerates the ideas: it sends Claude the original prompt,
    Claude's previous reply and the feedback, after deleting the previous
    image files. The user then picks one scene and clicks "Refine" to move to
    "Page 4", or "Go Back!" to return to "Page 2".
    """
    st.title("Scene Suggestions")
    st.write("Based on your uploaded product and references!")
    feedback = st.chat_input("Provide feedback:")

    if not st.session_state.get("assistant_initialized", False):
        string_caption_list_product = str(
            _caption_uploaded_images(st.session_state['uploaded_files'], "temp_image_"))
        string_caption_list = str(
            _caption_uploaded_images(st.session_state['reference_images'], "temp2_image_"))
        st.session_state["caption_product"] = string_caption_list_product
        st.session_state["additional_caption"] = string_caption_list
        # Separating spaces added: the pieces used to be glued together
        # ("...brand needsEditorial", "[captions]product text").
        additional_info_param_for_prompt = (
            f"Brand have provided reference images whose details are:"
            f"```{string_caption_list}```. Apart from this brand needs "
            f"{st.session_state['shoot_type']}"
        )
        product_info = f"{string_caption_list_product} {st.session_state['product_info']}"
        updated_prompt_for_claude = (
            st.session_state["system_prompt"]
            .replace("{{BRAND_DETAILS}}", str(st.session_state['brand_summary']))
            .replace("{{PRODUCT_DETAILS}}", product_info)
            .replace("{{ADDITIONAL_INFO}}", additional_info_param_for_prompt)
        )
        print(f"UP PROMPT:{updated_prompt_for_claude}")
        st.session_state["updated_prompt"] = updated_prompt_for_claude
        response_from_claude = create_claude_request_for_text_completion(
            [_text_message("user", updated_prompt_for_claude)])
        _generate_and_store_campaigns(response_from_claude)
        # Mark as initialized only after everything succeeded. Otherwise a
        # failure would leave the flag set with no "images" in state, and
        # every later run would crash.
        st.session_state.assistant_initialized = True

    if feedback:
        updated_context = st.session_state["claude_context"]
        if 'images' in st.session_state and 'descriptions' in st.session_state:
            for image_path in st.session_state['images']:
                try:
                    os.remove(image_path)
                except FileNotFoundError:
                    pass
            del st.session_state['images']
            del st.session_state['descriptions']
            del st.session_state["claude_context"]
        response_from_claude = create_claude_request_for_text_completion([
            _text_message("user", st.session_state["updated_prompt"]),
            _text_message("assistant", updated_context),
            _text_message("user", feedback),
        ])
        _generate_and_store_campaigns(response_from_claude)

    images = st.session_state["images"]
    descriptions = st.session_state["descriptions"]
    cols = st.columns(4)
    for i, (image_path, description) in enumerate(zip(images, descriptions)):
        # Wrap onto the 4 columns instead of raising IndexError past the 4th idea.
        with cols[i % len(cols)]:
            st.image(image_path, caption=description, use_column_width=True)

    # One radio over all scenes. The previous per-image single-option radios
    # were always truthy, so the last image was always "selected".
    selected_image_index = None
    if images:
        selected_image_index = st.radio(
            "Select a scene",
            list(range(len(images))),
            format_func=lambda idx: f"Select Image {idx + 1}",
            horizontal=True,
        )
    if selected_image_index is not None and st.button("Refine"):
        st.session_state.selected_image_index = selected_image_index
        st.session_state.selected_image = images[selected_image_index]
        st.session_state.selected_text = descriptions[selected_image_index]
        st.session_state['page'] = "Page 4"
    if st.button("Go Back!"):
        st.session_state.page = "Page 2"
def page4():
    """Render the refinement chat for the scene chosen on "Page 3".

    The sidebar shows:
    * the product info and images;
    * the chosen scene image;
    * the generation settings: a dimensions string (e.g. "3:4"), a fixed or
      random seed, and an editable model & background prompt pre-filled with
      the chosen scene's description.

    Once dimensions and a prompt are filled in:
    * "Start Chat" renders an image for the sidebar prompt.
    * Each chat message asks Claude (``PROMPT_TO_UPDATE_IDEA_OR_MOOD``) to
      rewrite the sidebar prompt following the user's instruction, then
      renders an image for the rewritten prompt.

    All turns are kept in ``st.session_state["mood_chat_messages"]`` and are
    shown again on every run.
    """
    selected_theme_text_by_user = st.session_state.descriptions[st.session_state.selected_image_index]
    print(selected_theme_text_by_user)
    with st.sidebar:
        st.title(st.session_state["product_info"])
        st.write("Product Image")
        st.image(st.session_state['uploaded_files'])
        st.text("Scene Suggestion:")
        st.image(st.session_state.selected_image)
        dimensions = st.text_input("Enter Dimensions e.g 3:4, 1:2", key="Dimensions")
        seed = st.selectbox(
            "Seed Preference",
            ("Fixed", "Random"),
        )
        if seed == "Fixed":
            seed_number = st.number_input("Enter an integer:", min_value=1, max_value=100000, value=10, step=1)
        else:
            seed_number = 0
            st.text("Thanks will take care")
        model__bg_preference = st.text_area("Edit Model & BG Idea", value=selected_theme_text_by_user,
                                            key="Model & BG Idea")
        start_chat = st.button("Start Chat")

    if "mood_chat_messages" not in st.session_state:
        st.session_state["mood_chat_messages"] = []
    # Same list object as in session state, so appends persist across reruns.
    chat_messages = st.session_state["mood_chat_messages"]

    def generate_image(prompt):
        """Render *prompt* with Flux according to the seed preference."""
        # NOTE(review): only the fixed-seed helper receives `dimensions`; the
        # random-seed path ignores them. Confirm whether flux_generated_image
        # should take dimensions too.
        if seed == "Fixed":
            return flux_generated_image_seed(prompt, seed_number, dimensions)
        return flux_generated_image(prompt)

    if seed and dimensions and model__bg_preference:
        if start_chat:
            generated_flux_image = generate_image(model__bg_preference)
            chat_messages.append({
                "role": "AI",
                "message": model__bg_preference,
                "image": generated_flux_image["file_name"],
            })
        user_input = st.chat_input("Type your message here...")
        if user_input:
            chat_messages.append({"role": "User", "message": user_input})
            updated_flux_prompt = prompts["PROMPT_TO_UPDATE_IDEA_OR_MOOD"].format(
                EXISTING_MODEL_BG_PROMPT=model__bg_preference,
                USER_INSTRUCTIONS=user_input,
            )
            message_schema_for_claude = [
                {"role": "user", "content": [{"type": "text", "text": updated_flux_prompt}]},
                {"role": "assistant", "content": [{"type": "text", "text": str(chat_messages)}]},
                # Space added so the note is not glued onto the user's last word.
                {"role": "user", "content": [{
                    "type": "text",
                    "text": user_input + " Reference of previous conversation is also added.",
                }]},
            ]
            response_from_claude = create_claude_request_for_text_completion(message_schema_for_claude)
            cleaned_prompt = extract_data_from_text_xml(response_from_claude, "updated_prompt")
            if cleaned_prompt:
                generated_flux_image_n = generate_image(cleaned_prompt)
                chat_messages.append({
                    "role": "AI",
                    "message": cleaned_prompt,
                    "actual_response": response_from_claude,
                    "image": generated_flux_image_n["file_name"],
                })
            else:
                # With no <updated_prompt> in the reply there is nothing to
                # render. Tell the user instead of sending an empty prompt to Flux.
                st.warning("The AI reply did not contain an updated prompt. Please try rephrasing your request.")
        for message in chat_messages:
            if message["role"] == "AI":
                st.write(f"**AI**: {message['message']}")
                st.image(message['image'])
            else:
                st.write(f"**You**: {message['message']}")
    print(seed_number)
# Route to the page stored in session state. The first visit lands on the
# preferences page ("Page 5").
if "page" not in st.session_state:
    st.session_state.page = "Page 5"

# "Page 5" has its own `if`, separate from the lookup below. When its
# "Start Testing!" button sets the page to "Page 1", the lookup reads the new
# value and Page 1 is rendered in the same script run.
if st.session_state.page == "Page 5":
    page5()

_PAGE_RENDERERS = {
    "Page 1": page1,
    "Page 2": page2,
    "Page 3": page3,
    "Page 4": page4,
}
_render_page = _PAGE_RENDERERS.get(st.session_state.page)
if _render_page is not None:
    _render_page()