commit 35f30a45139654a2607cb7c4beacd9b278b066ed
Author: callmeyan
Date:   Thu Dec 14 16:19:54 2023 +0800

    init customize base on streamlit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3e3726c
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+venv
+.idea
\ No newline at end of file
diff --git a/customize.css b/customize.css
new file mode 100644
index 0000000..e69de29
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..5596b44
--- /dev/null
+++ b/main.py
@@ -0,0 +1,16 @@
+# This is a sample Python script.
+
+# Press Shift+F10 to execute it or replace it with your code.
+# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
+
+
+def print_hi(name):
+    # Use a breakpoint in the code line below to debug your script.
+    print(f'Hi, {name}')  # Press Ctrl+F8 to toggle the breakpoint.
+
+
+# Press the green button in the gutter to run the script.
+if __name__ == '__main__':
+    print_hi('PyCharm')
+
+# See PyCharm help at https://www.jetbrains.com/help/pycharm/
diff --git a/origin.py b/origin.py
new file mode 100644
index 0000000..fb7d5c6
--- /dev/null
+++ b/origin.py
@@ -0,0 +1,89 @@
+import os
+
+import streamlit as st
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
+TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+# Set the page title, icon, and layout
+st.set_page_config(
+    page_title="ChatGLM3-6B Demo",
+    page_icon=":robot:",
+    layout="wide"
+)
+
+
+@st.cache_resource
+def get_model():
+    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
+    if 'cuda' in DEVICE:  # AMD and NVIDIA GPUs can use half precision
+        model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).to(DEVICE).eval()
+    else:  # CPUs and other devices fall back to full float32 precision
+        model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float().to(DEVICE).eval()
+    # For multi-GPU support, use the two lines below instead of the line above,
+    # and set num_gpus to the number of GPUs you actually have:
+    # from utils import load_model_on_gpus
+    # model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
+    return tokenizer, model
+
+
+# Load the ChatGLM3 model and tokenizer
+tokenizer, model = get_model()
+
+# Initialize the chat history and past key values
+if "history" not in st.session_state:
+    st.session_state.history = []
+if "past_key_values" not in st.session_state:
+    st.session_state.past_key_values = None
+
+# Sliders for max_length, top_p, and temperature
+max_length = st.sidebar.slider("max_length", 0, 32768, 8192, step=1)
+top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
+temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.6, step=0.01)

+# Clear the chat history
+buttonClean = st.sidebar.button("Clear chat history", key="clean")
+if buttonClean:
+    st.session_state.history = []
+    st.session_state.past_key_values = None
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    st.rerun()
+
+# Render the chat history
+for i, message in enumerate(st.session_state.history):
+    if message["role"] == "user":
+        with st.chat_message(name="user", avatar="user"):
+            st.markdown(message["content"])
+    else:
+        with st.chat_message(name="assistant", avatar="assistant"):
+            st.markdown(message["content"])
+
+# Input and output placeholders
+with st.chat_message(name="user", avatar="user"):
+    input_placeholder = st.empty()
+with st.chat_message(name="assistant", avatar="assistant"):
+    message_placeholder = st.empty()
+
+# Get the user input
+prompt_text = st.chat_input("Please enter your question")
+
+# If the user entered something, generate a reply
+if prompt_text:
+    input_placeholder.markdown(prompt_text)
+    history = st.session_state.history
+    past_key_values = st.session_state.past_key_values
+    for response, history, past_key_values in model.stream_chat(
+        tokenizer,
+        prompt_text,
+        history,
+        past_key_values=past_key_values,
+        max_length=max_length,
+        top_p=top_p,
+        temperature=temperature,
+        return_past_key_values=True,
+    ):
+        message_placeholder.markdown(response)
+
+    # Update the chat history and past key values
+    st.session_state.history = history
+    st.session_state.past_key_values = past_key_values
\ No newline at end of file
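The streaming loop at the end of origin.py is not tied to Streamlit. For reference, here is a minimal console sketch of the same stream_chat pattern, assuming ChatGLM3's remote-code API exactly as origin.py invokes it; chat_loop is an illustrative name, not something shipped by this commit:

import os

import torch
from transformers import AutoModel, AutoTokenizer

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def chat_loop():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True)
    if 'cuda' in DEVICE:
        model = model.to(DEVICE).eval()          # half precision on GPU
    else:
        model = model.float().to(DEVICE).eval()  # float32 on CPU
    history, past_key_values = [], None
    while True:
        prompt = input("You: ")
        if not prompt:
            break
        response = ""
        # stream_chat yields progressively longer partial responses; threading
        # history and past_key_values back in lets each turn reuse the cached
        # attention state instead of re-encoding the whole conversation.
        for response, history, past_key_values in model.stream_chat(
            tokenizer, prompt, history,
            past_key_values=past_key_values,
            return_past_key_values=True,
        ):
            pass
        print("Assistant:", response)


if __name__ == '__main__':
    chat_loop()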
diff --git a/run.bat b/run.bat
new file mode 100644
index 0000000..925e740
--- /dev/null
+++ b/run.bat
@@ -0,0 +1 @@
+streamlit run web_demo.py
\ No newline at end of file
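The commented-out alternative inside origin.py's get_model() points at a multi-GPU path. Expanded into working form it would look roughly like the sketch below; load_model_on_gpus comes from the upstream ChatGLM3 repository's utils module, which this commit does not include (an assumed dependency), and get_model_multi_gpu is an illustrative name:

import streamlit as st
from transformers import AutoTokenizer


@st.cache_resource
def get_model_multi_gpu():
    # load_model_on_gpus is the upstream ChatGLM3 helper, not a file in this repo.
    from utils import load_model_on_gpus
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
    # Shards the weights across GPUs; set num_gpus to the number actually installed.
    model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
    return tokenizer, model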
diff --git a/web_demo.py b/web_demo.py
new file mode 100644
index 0000000..66d8d7e
--- /dev/null
+++ b/web_demo.py
@@ -0,0 +1,105 @@
+import streamlit as st
+
+st.set_page_config(
+    page_title="web-demo",
+    page_icon=":web-demo:",
+    layout="wide",
+)
+st.markdown(""" """, unsafe_allow_html=True)
+
+# Initialize the chat history and past key values
+if "history" not in st.session_state:
+    st.session_state.history = []
+if "past_key_values" not in st.session_state:
+    st.session_state.past_key_values = None
+
+# Sliders for max_length, top_p, and temperature
+# max_length = st.sidebar.slider("max_length", 0, 32768, 8192, step=1)
+# top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
+# temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.6, step=0.01)
+
+# Clear the chat history
+# buttonClean = st.sidebar.button("Clear chat history", key="clean")
+# if buttonClean:
+#     st.session_state.history = []
+#     st.session_state.past_key_values = None
+#     st.rerun()
+
+mainContainer = st.container()
+if len(st.session_state.history) == 0:
+    emptyContent = st.markdown("""
+    <div>
+        <div>
+            <h4>Examples</h4>
+            <p>Explain the law of conservation of mass in simple terms</p>
+            <p>Got any creative ideas for a 10-year-old's birthday?</p>
+            <p>How do I make an HTTP request in Javascript?</p>
+        </div>
+        <div>
+            <h4>Capabilities</h4>
+            <p>Remembers what the user said earlier in the conversation</p>
+            <p>Allows the user to provide follow-up corrections</p>
+            <p>Declines inappropriate requests</p>
+        </div>
+        <div>
+            <h4>Limitations</h4>
+            <p>May occasionally generate incorrect information</p>
+            <p>May occasionally produce harmful instructions or biased content</p>
+            <p>Limited knowledge of the world and events after 2021</p>
+        </div>
+    </div>
+    """, unsafe_allow_html=True)
+else:
+    # Render the chat history
+    for i, message in enumerate(st.session_state.history):
+        if message["role"] == "user":
+            with mainContainer.chat_message(name="user", avatar="user"):
+                st.markdown(message["content"])
+        else:
+            with mainContainer.chat_message(name="assistant", avatar="assistant"):
+                st.markdown(message["content"])
+
+    # Input and output placeholders
+    with mainContainer.chat_message(name="user", avatar="user"):
+        input_placeholder = st.empty()
+    with mainContainer.chat_message(name="assistant", avatar="assistant"):
+        message_placeholder = st.empty()
+
+# Get the user input
+prompt_text = st.chat_input("Please enter your question")
+
+# If the user entered something, generate a reply
+if prompt_text:
+    # Handle the case where the history is still empty
+    if len(st.session_state.history) == 0:
+        emptyContent.empty()
+        # Input and output placeholders
+        with mainContainer.chat_message(name="user", avatar="user"):
+            input_placeholder = st.empty()
+        with mainContainer.chat_message(name="assistant", avatar="assistant"):
+            message_placeholder = st.empty()
+
+    history = st.session_state.history
+    past_key_values = st.session_state.past_key_values
+    history.append({
+        'content': prompt_text, "role": 'user'
+    })
+    response = 'Your question: "' + prompt_text + '" cannot be handled yet! Please ask another question!'
+    history.append({
+        'content': response, "role": 'system'
+    })
+    input_placeholder.markdown(prompt_text)
+    message_placeholder.markdown(response)
+    # Update the chat history and past key values
+    st.session_state.history = history
+    st.session_state.past_key_values = past_key_values
+
+# st.markdown("""
+#
+# """, unsafe_allow_html=True)