This commit is contained in:
raykkk
2025-10-17 21:40:45 +08:00
commit 7d0451131f
155 changed files with 14873 additions and 0 deletions

View File

@@ -0,0 +1,467 @@
# -*- coding: utf-8 -*-
import json
import logging
import os
from datetime import datetime
from typing import Tuple, Optional, Union, Dict, Any, Generator
import requests
from dotenv import load_dotenv
from flask import Flask, jsonify, request
from flask_cors import CORS
from flask_sqlalchemy import SQLAlchemy
from werkzeug.security import check_password_hash, generate_password_hash
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
load_dotenv()
app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})
# Configure database
basedir = os.path.abspath(os.path.dirname(__file__))
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///" + os.path.join(
basedir,
"ai_assistant.db",
)
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db: SQLAlchemy = SQLAlchemy(app)
# Database models
class User(db.Model):
id = db.Column(db.Integer, primary_key=True)
username = db.Column(db.String(80), unique=True, nullable=False)
password_hash = db.Column(db.String(120), nullable=False)
name = db.Column(db.String(100), nullable=False)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
# Relationships
conversations = db.relationship(
"Conversation",
backref="user",
lazy=True,
cascade="all, delete-orphan",
)
def set_password(self, password: str) -> None:
self.password_hash = generate_password_hash(password)
def check_password(self, password: str) -> bool:
return check_password_hash(self.password_hash, password)
class Conversation(db.Model):
id = db.Column(db.Integer, primary_key=True)
title = db.Column(db.String(200), nullable=False)
user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
updated_at = db.Column(
db.DateTime,
default=datetime.utcnow,
onupdate=datetime.utcnow,
)
# Relationships
messages = db.relationship(
"Message",
backref="conversation",
lazy=True,
cascade="all, delete-orphan",
)
class Message(db.Model):
id = db.Column(db.Integer, primary_key=True)
text = db.Column(db.Text, nullable=False)
sender = db.Column(db.String(20), nullable=False) # 'user' or 'ai'
conversation_id = db.Column(
db.Integer,
db.ForeignKey("conversation.id"),
nullable=False,
)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
# Create database tables
def create_tables() -> None:
db.create_all()
# Create sample users (if none exist)
if not User.query.first():
user1 = User(username="user1", name="Bruce")
user1.set_password("password123")
user2 = User(username="user2", name="John")
user2.set_password("password456")
db.session.add(user1)
db.session.add(user2)
db.session.commit()
# functions
def parse_sse_line(
line: bytes,
) -> Tuple[Optional[str], Optional[Union[str, int]]]:
line = line.decode("utf-8").strip()
if line.startswith("data: "):
return "data", line[6:]
elif line.startswith("event:"):
return "event", line[7:]
elif line.startswith("id: "):
return "id", line[4:]
elif line.startswith("retry:"):
return "retry", int(line[7:])
return None, None
def sse_client(
url: str,
data: Optional[Dict[str, Any]] = None,
) -> Generator[str, None, None]:
headers = {
"Accept": "text/event-stream",
"Cache-Control": "no-cache",
}
if data is not None:
response = requests.post(
url,
stream=True,
headers=headers,
json=data,
)
else:
response = requests.get(
url,
stream=True,
headers=headers,
)
for line in response.iter_lines():
if line:
field, value = parse_sse_line(line)
if field == "data":
try:
data = json.loads(value)
if (
data["object"] == "content"
and data["delta"] is True
and data["type"] == "text"
):
yield data["text"]
except json.JSONDecodeError:
pass
def call_runner(
query: str,
query_user_id: str,
query_session_id: str,
) -> Generator[str, None, None]:
server_port = int(os.environ.get("SERVER_PORT", "8090"))
server_endpoint = os.environ.get("SERVER_ENDPOINT", "agent")
server_host = os.environ.get("SERVER_HOST", "localhost")
url = f"http://{server_host}:{server_port}/{server_endpoint}"
data_arg: Dict[str, Any] = {
"input": [
{
"role": "user",
"content": [
{
"type": "text",
"text": query,
},
],
},
],
"session_id": query_session_id,
"user_id": query_user_id,
}
for content in sse_client(url, data=data_arg):
yield content
# API routes
# User login
@app.route("/api/login", methods=["POST"])
def login():
data = request.get_json()
username = data.get("username")
password = data.get("password")
if not username or not password:
return jsonify({"error": "Username and password cannot be empty"}), 400
user = User.query.filter_by(username=username).first()
if user and user.check_password(password):
return (
jsonify(
{
"id": user.id,
"username": user.username,
"name": user.name,
"created_at": user.created_at.isoformat(),
},
),
200,
)
else:
return jsonify({"error": "Invalid username or password"}), 401
# Get all user conversations
@app.route("/api/users/<int:user_id>/conversations", methods=["GET"])
def get_user_conversations(user_id):
User.query.get_or_404(user_id)
conversations = (
Conversation.query.filter_by(user_id=user_id)
.order_by(
Conversation.updated_at.desc(),
)
.all()
)
result = []
for conv in conversations:
# Get the last message as preview
last_message = (
Message.query.filter_by(
conversation_id=conv.id,
)
.order_by(
Message.created_at.desc(),
)
.first()
)
preview = last_message.text if last_message else ""
result.append(
{
"id": conv.id,
"title": conv.title,
"user_id": conv.user_id,
"preview": preview,
"created_at": conv.created_at.isoformat(),
"updated_at": conv.updated_at.isoformat(),
},
)
return jsonify(result), 200
# Create new conversation
@app.route("/api/users/<int:user_id>/conversations", methods=["POST"])
def create_conversation(user_id):
User.query.get_or_404(user_id)
data = request.get_json()
title = data.get(
"title",
f'Conversation {datetime.now().strftime("%Y-%m-%d %H:%M")}',
)
conversation = Conversation(title=title, user_id=user_id)
db.session.add(conversation)
db.session.commit()
# Create welcome message
welcome_message = Message(
text="Hello! I am your AI assistant. How can I help you today?",
sender="ai",
conversation_id=conversation.id,
)
db.session.add(welcome_message)
db.session.commit()
return (
jsonify(
{
"id": conversation.id,
"title": conversation.title,
"user_id": conversation.user_id,
"created_at": conversation.created_at.isoformat(),
"updated_at": conversation.updated_at.isoformat(),
},
),
201,
)
# Get conversation details and messages
@app.route("/api/conversations/<int:conversation_id>", methods=["GET"])
def get_conversation(conversation_id):
conversation = Conversation.query.get_or_404(conversation_id)
messages = (
Message.query.filter_by(
conversation_id=conversation_id,
)
.order_by(
Message.created_at.asc(),
)
.all()
)
messages_data = []
for msg in messages:
messages_data.append(
{
"id": msg.id,
"text": msg.text,
"sender": msg.sender,
"created_at": msg.created_at.isoformat(),
},
)
return (
jsonify(
{
"id": conversation.id,
"title": conversation.title,
"user_id": conversation.user_id,
"messages": messages_data,
"created_at": conversation.created_at.isoformat(),
"updated_at": conversation.updated_at.isoformat(),
},
),
200,
)
# Send message
@app.route(
"/api/conversations/<int:conversation_id>/messages",
methods=["POST"],
)
def send_message(conversation_id):
conversation = Conversation.query.get_or_404(conversation_id)
data = request.get_json()
text = data.get("text")
sender = data.get("sender", "user")
if not text:
return jsonify({"error": "Message content cannot be empty"}), 400
# Create user message
user_message = Message(
text=text,
sender=sender,
conversation_id=conversation_id,
)
db.session.add(user_message)
# Update conversation title (if it's the first user message)
if sender == "user" and len(conversation.messages) <= 1:
conversation.title = text[:20] + ("..." if len(text) > 20 else "")
db.session.commit()
if sender == "user":
ai_response_text = ""
question = text
conversation_id_str = str(conversation_id)
for item in call_runner(
question,
conversation_id_str,
conversation_id_str,
):
ai_response_text += item
ai_message = Message(
text=ai_response_text,
sender="ai",
conversation_id=conversation_id,
)
db.session.add(ai_message)
db.session.commit()
return (
jsonify(
{
"id": user_message.id,
"text": user_message.text,
"sender": user_message.sender,
"created_at": user_message.created_at.isoformat(),
},
),
201,
)
# Delete conversation
@app.route("/api/conversations/<int:conversation_id>", methods=["DELETE"])
def delete_conversation(conversation_id):
conversation = Conversation.query.get_or_404(conversation_id)
db.session.delete(conversation)
db.session.commit()
return jsonify({"message": "Conversation deleted successfully"}), 200
# Update conversation title
@app.route("/api/conversations/<int:conversation_id>", methods=["PUT"])
def update_conversation(conversation_id):
conversation = Conversation.query.get_or_404(conversation_id)
data = request.get_json()
if "title" in data:
conversation.title = data["title"]
db.session.commit()
return (
jsonify(
{
"id": conversation.id,
"title": conversation.title,
"user_id": conversation.user_id,
"created_at": conversation.created_at.isoformat(),
"updated_at": conversation.updated_at.isoformat(),
},
),
200,
)
# Get user information
@app.route("/api/users/<int:user_id>", methods=["GET"])
def get_user(user_id):
user = User.query.get_or_404(user_id)
return (
jsonify(
{
"id": user.id,
"username": user.username,
"name": user.name,
"created_at": user.created_at.isoformat(),
},
),
200,
)
# Error handling
@app.errorhandler(404)
def not_found(error):
logger.error(error)
return jsonify({"error": "Resource not found"}), 404
@app.errorhandler(500)
def internal_error(error):
logger.error(error)
db.session.rollback()
return jsonify({"error": "Internal server error"}), 500
if __name__ == "__main__":
app.run(debug=True, host="0.0.0.0", port=5100)