From 96ffcae9531b247bc5ad02446acb94e3efd543cb Mon Sep 17 00:00:00 2001 From: Hayden Johnson Date: Thu, 10 Apr 2025 22:19:01 -0700 Subject: [PATCH] feat: Begin refactor into class structure This commit introduces a significant refactor of the AI assistant's codebase, transitioning from a procedural approach to an object-oriented design. This change improves modularity, maintainability, and extensibility. Key changes include: - Introduction of classes: AIAssistant, CommandLineParser, InputHandler, and CommandParser. - Encapsulation of core functionality within the AIAssistant class. - Improved input handling with separation of interactive and piped input. - Implementation of command parsing for actions like saving, clearing, and exiting. - Refactored history management for persistent conversation storage. --- assistant.py | 23 +++-- refactored.py | 260 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 275 insertions(+), 8 deletions(-) create mode 100644 refactored.py diff --git a/assistant.py b/assistant.py index 94bf203..3fb812f 100755 --- a/assistant.py +++ b/assistant.py @@ -91,25 +91,30 @@ def highlight_code(language_name, code): lexer = get_lexer_by_name(lexer_name) except ValueError: # If the lexer is not found, guess it - lexer = guess_lexer(code.split('\n')[1:-1]) + lexer = guess_lexer('\n'.join(code.split('\n')[1:-1])) if not lexer: # If no lexer is guessed, default to bash lexer = get_lexer_by_name('bash') else: # If no language is specified, guess the lexer - lexer = guess_lexer('\n'.join(code.split('\n')[1:-1])) - if not lexer: - # If no lexer is guessed, default to bash + try: + lexer = guess_lexer('\n'.join(code.split('\n')[1:-1])) + if not lexer: + # If no lexer is guessed, default to bash + lexer = get_lexer_by_name('bash') + except: lexer = get_lexer_by_name('bash') formatter = TerminalFormatter() - just_code = code.split('\n')[0] + newlines = '\n'.join(code.split('\n')[1:]) # if code is a code block, strip surrounding block markers lines = code.split('\n') if (len(lines) > 2) and ('```' in lines[0]) and ('```' in lines[-1]): just_code = '\n'.join(code.split('\n')[1:-1]) + else: + just_code = code.split('\n')[0] # Inline code highlighted_code = pygments.highlight(just_code, lexer, formatter) return highlighted_code + newlines @@ -338,10 +343,12 @@ def improved_input(): Returns the complete input text when the user indicates they're done. """ lines = [] + print("> ", end='', flush=True) while True: try: line = sys.stdin.readline() if not line: + print("", flush=True) # Some visual feedback that user input has ended break # EOF if line.strip() == '': continue # ignore empty lines @@ -389,8 +396,8 @@ def handle_non_piped_input(args): except (EOFError, KeyboardInterrupt): print("\nExiting...") break - except Exception as e: - print(f"Error: {e}") + # except Exception as e: + # print(f"Error: {e}") client = None @@ -409,7 +416,7 @@ def main(): temp = float(args.temp) if args.context: global num_ctx - num_ctx = float(args.context) + num_ctx = int(args.context) if args.new: global history history = [system_prompt] diff --git a/refactored.py b/refactored.py new file mode 100644 index 0000000..a607b1f --- /dev/null +++ b/refactored.py @@ -0,0 +1,260 @@ +from ollama import Client +import re +import pyperclip +import sys +import argparse +import pygments +from pygments.lexers import get_lexer_by_name, guess_lexer +from pygments.formatters import TerminalFormatter +import os +import json + +class AIAssistant: + def __init__(self, server="http://localhost:11434", model="gemma3:12b"): + self.server = server + self.model = model + self.client = Client(host=self.server) + self.temperature = 0.2 + self.num_ctx = 4096 + self.history = [self.system_prompt()] + self.load_history() + + def set_host(self, host): + self.server = host + self.client = Client(host=host) + + def system_prompt(self): + return {"role": "system", "content": "You are a helpful, smart, kind, and efficient AI assistant. You always fulfill the user's requests accurately and concisely."} + + def load_history(self): + path = os.environ.get('HOME') + '/.cache/ai-assistant.history' + try: + with open(path, 'r') as f: + self.history = json.load(f) + except FileNotFoundError: + pass + + def save_history(self): + path = os.environ.get('HOME') + '/.cache/ai-assistant.history' + with open(path, 'w+') as f: + json.dump(self.history, f) + + def chat(self, message, stream=True): + self.history.append({"role": "user", "content": message}) + completion = self.client.chat( + model=self.model, + options={"temperature": self.temperature, "num_ctx": self.num_ctx}, + messages=self.history, + stream=stream + ) + result = '' + large_chunk = [] + for chunk in completion: + text = chunk['message']['content'] + large_chunk.append(text) + if stream: + print(text, end='', flush=True) + if not stream: + result = completion['message']['content'] + else: + result = ''.join(large_chunk) + self.history.append({"role": 'assistant', 'content': result}) + self.save_history() + return result + +class CommandLineParser: + def __init__(self): + self.parser = argparse.ArgumentParser(description='Chat with an intelligent assistant') + self.add_arguments() + + def add_arguments(self): + parser = self.parser + parser.add_argument('--host', nargs='?', const=True, default=False, help='Specify host of Ollama server') + parser.add_argument('--model', '-m', nargs='?', const=True, default=False, help='Specify model') + parser.add_argument('--temp', '-t', nargs='?', type=float, const=0.2, default=False, help='Specify temperature') + parser.add_argument('--context', type=int, default=4096, help='Specify context size') + parser.add_argument('--reasoning', '-r', action='store_true', help='Use the default reasoning model deepseek-r1:14b') + parser.add_argument('--new', '-n', action='store_true', help='Start a chat with a fresh history') + parser.add_argument('--follow-up', '-f', nargs='?', const=True, default=False, help='Ask a follow up question when piping in context') + parser.add_argument('--copy', '-c', action='store_true', help='Copy a codeblock if it appears') + parser.add_argument('--shell', '-s', nargs='?', const=True, default=False, help='Output a shell command that does as described') + + def parse(self): + return self.parser.parse_args() + +class InputHandler: + def __init__(self, assistant): + self.assistant = assistant + + def handle_input(self, args): + if not sys.stdin.isatty(): + self.handle_piped_input(args) + else: + self.handle_interactive_input(args) + + def handle_piped_input(self, args): + all_input = sys.stdin.read() + query = f'Use the following context to answer the question. There will be no follow up questions from the user so make sure your answer is complete:\n{all_input}\n' + if args.copy: + query += 'Answer the question using a codeblock for any code or shell scripts\n' + if args.follow_up: + second_input = input('> ') + query += f'\n{second_input}' + result = self.assistant.chat(query, stream=False) + blocks = self.extract_code_block(result) + if args.copy and len(blocks): + pyperclip.copy(blocks[0]) + + def handle_interactive_input(self, args): + print("\033[91massistant\033[0m: Type your message (press Ctrl+D to send):") + while True: + try: + full_input = self.improved_input() + if full_input is None: + break + if full_input.strip() == '': + continue + print("\033[91massistant\033[0m: ", end='') + result = self.assistant.chat(full_input) + print() + except (EOFError, KeyboardInterrupt): + print("\nExiting...") + break + + def improved_input(self): + lines = [] + print("> ", end='', flush=True) + while True: + try: + line = sys.stdin.readline() + if not line: + print(flush=True) + break + if line.strip() == '': + continue + lines.append(line) + if len(lines) > 1 and lines[-2] == '\n' and lines[-1] == '\n': + break + except KeyboardInterrupt: + print("\nUser aborted input") + return None + full_input = ''.join(lines).rstrip('\n') + return full_input + + def extract_code_block(self, text): + pattern = r'```[a-z]*\n[\s\S]*?\n```' + code_blocks = [] + matches = re.finditer(pattern, text) + for match in matches: + code_block = match.group(0) + lexer_name = self.determine_lexer(code_block) + highlighted_code = self.highlight_code(lexer_name, code_block) + code_blocks.append(highlighted_code) + if not code_blocks: + line_pattern = r'`[a-z]*[\s\S]*?`' + line_matches = re.finditer(line_pattern, text) + for match in line_matches: + code_block = match.group(0) + code_blocks.append(code_block[1:-1]) + return code_blocks + + def determine_lexer(self, code_block): + lexer_name = None + lines = code_block.split('\n') + for line in lines: + if line.strip().startswith('```'): + lexer_part = line.strip().split('```')[1].strip() + if lexer_part: + lexer_name = lexer_part + break + elif line.strip().startswith('lang:'): + lexer_part = line.strip().split(':')[1].strip() + if lexer_part: + lexer_name = lexer_part + break + return lexer_name + + def highlight_code(self, lexer_name, code): + try: + lexer = get_lexer_by_name(lexer_name) if lexer_name else guess_lexer(code) + except ValueError: + lexer = guess_lexer('\n'.join(code.split('\n')[1:-1])) + if not lexer: + lexer = get_lexer_by_name('bash') + formatter = TerminalFormatter() + highlighted_code = pygments.highlight(code, lexer, formatter) + return highlighted_code + +class CommandParser: + def __init__(self): + self.commands = { + '/save': self.handle_save, + '/clear': self.handle_clear, + '/clipboard': self.handle_clipboard, + '/exit': self.handle_exit + } + + def parse_commands(self, text): + tokens = text.split(' ') + if not tokens: + return False + command = tokens[0] + if command in self.commands: + handler = self.commands[command] + if len(tokens) > 1 and command == '/clipboard': + context_query = '\n\nThe following is context provided by the user:\n' + clipboard_content = pyperclip.paste() + if clipboard_content: + context_query += clipboard_content + '\n' + return handler(context_query) + else: + handler() + return True + return False + + def handle_save(self): + filename = input('Enter filename to save conversation: ') + self.save_conversation(filename) + + def save_conversation(self, filename='conversation.md'): + if not filename.endswith('.md'): + filename += '.md' + base, extension = os.path.splitext(filename) + i = 1 + while os.path.exists(filename): + filename = f"{base}_{i}{extension}" + i += 1 + with open(filename, 'w') as f: + f.write(conversation) + + def handle_clear(self): + self.assistant.history = [self.assistant.system_prompt()] + self.assistant.save_history() + + def handle_clipboard(self, context_query): + # Implementation for clipboard command + pass + + def handle_exit(self): + sys.exit(0) + +def main(): + args = CommandLineParser().parse() + assistant = AIAssistant() + if args.host: + assistant.set_host(args.host) + if args.model: + assistant.model = args.model + if args.temp: + assistant.temperature = args.temp + if args.context: + assistant.num_ctx = args.context + if args.new: + assistant.history = [assistant.system_prompt()] + assistant.save_history() + + input_handler = InputHandler(assistant) + input_handler.handle_input(args) + +if __name__ == '__main__': + main()