diff --git a/assistant.py b/assistant.py index c8a6ef3..b6b93c5 100644 --- a/assistant.py +++ b/assistant.py @@ -3,6 +3,8 @@ import argparse import os import sys +import datetime +import sqlite3 import json from ollama import Client @@ -17,6 +19,7 @@ from prompt_toolkit.key_binding import KeyBindings from prompt_toolkit import PromptSession + class AIAssistant: def __init__(self, server="http://localhost:11434", model="qwen3:14b"): self.server = server @@ -25,6 +28,75 @@ class AIAssistant: self.temperature = 0.2 self.num_ctx = 4096 self.history = [self.system_prompt()] + self.db_path = os.path.expanduser("~/.cache/ai-assistant.db") + self._init_db() + + def _init_db(self): + """Initialize SQLite database and create the conversations table.""" + if not os.path.exists(self.db_path): + self._create_db() + + def _create_db(self): + """Create the conversations table in the SQLite database.""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute(''' + CREATE TABLE IF NOT EXISTS conversations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + topic TEXT, + history TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + ''') + conn.commit() + conn.close() + + def _save_to_db(self, topic): + """Save the current conversation to the SQLite database.""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute(''' + INSERT INTO conversations (topic, history) + VALUES (?, ?) + ''', (topic, json.dumps(self.history))) + conn.commit() + conn.close() + + def _load_from_db(self, conversation_id): + """Load a conversation from the SQLite database by ID.""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute(''' + SELECT history FROM conversations WHERE id = ? + ''', (conversation_id,)) + result = cursor.fetchone() + conn.close() + if result: + self.history = json.loads(result[0]) + else: + self.history = [self.system_prompt()] + + def save_history(self): + """Save the current conversation to the database with a generated topic.""" + # Only save if this is the first user message + if len(self.history) == 3: + # Generate a topic using the AI + system_prompt = self.system_prompt() + user_prompt = "Generate a concise, descriptive topic for this conversation based on the following content:\n" + topic = self.client.chat(model=self.model, messages=[system_prompt, {"role": "user", "content": user_prompt}], stream=False)['message']['content'].strip() + self._save_to_db(topic) + else: + # For subsequent messages, we can update the topic in the future + pass + + def load_history(self, conversation_id=None): + """Load a conversation from the database by ID. If no ID, start a new one.""" + if conversation_id: + self._load_from_db(conversation_id) + else: + self.history = [self.system_prompt()] + + def set_host(self, host): self.server = host @@ -33,19 +105,6 @@ class AIAssistant: def system_prompt(self): return {"role": "system", "content": "You are a helpful, smart, kind, and efficient AI assistant. You always fulfill the user's requests accurately and concisely."} - def load_history(self): - path = os.environ.get('HOME') + '/.cache/ai-assistant.history' - try: - with open(path, 'r') as f: - self.history = json.load(f) - except FileNotFoundError: - pass - - def save_history(self): - path = os.environ.get('HOME') + '/.cache/ai-assistant.history' - with open(path, 'w+') as f: - json.dump(self.history, f) - def determine_lexer(self, code_block): lexer_name = None lines = code_block.split('\n') @@ -126,7 +185,7 @@ class CommandLineParser: parser.add_argument('--temp', '-t', nargs='?', type=float, const=0.2, default=False, help='Specify temperature') parser.add_argument('--context', type=int, default=4096, help='Specify context size') parser.add_argument('--reasoning', '-r', action='store_true', help='Use the default reasoning model deepseek-r1:14b') - parser.add_argument('--new', '-n', action='store_true', help='Start a chat with a fresh history') + parser.add_argument('--resume', action='store_true', help='Resume a previous conversation') parser.add_argument('--follow-up', '-f', nargs='?', const=True, default=False, help='Ask a follow up question when piping in context') parser.add_argument('--copy', '-c', action='store_true', help='Copy a codeblock if it appears') parser.add_argument('--shell', '-s', nargs='?', const=True, default=False, help='Output a shell command that does as described') @@ -324,11 +383,10 @@ def main(): assistant.temperature = args.temp if args.context: assistant.num_ctx = args.context - if args.new: - assistant.history = [assistant.system_prompt()] - assistant.save_history() - else: + if args.resume: assistant.load_history() + else: + assistant.history = [assistant.system_prompt()] command_parser = CommandParser() input_handler = InputHandler(assistant, command_parser)