feat: Begin refactor into class structure
This commit introduces a significant refactor of the AI assistant's codebase, transitioning from a procedural approach to an object-oriented design. This change improves modularity, maintainability, and extensibility. Key changes include: - Introduction of classes: AIAssistant, CommandLineParser, InputHandler, and CommandParser. - Encapsulation of core functionality within the AIAssistant class. - Improved input handling with separation of interactive and piped input. - Implementation of command parsing for actions like saving, clearing, and exiting. - Refactored history management for persistent conversation storage.
This commit is contained in:
parent
0dcc44870e
commit
96ffcae953
23
assistant.py
23
assistant.py
|
|
@ -91,25 +91,30 @@ def highlight_code(language_name, code):
|
||||||
lexer = get_lexer_by_name(lexer_name)
|
lexer = get_lexer_by_name(lexer_name)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# If the lexer is not found, guess it
|
# If the lexer is not found, guess it
|
||||||
lexer = guess_lexer(code.split('\n')[1:-1])
|
lexer = guess_lexer('\n'.join(code.split('\n')[1:-1]))
|
||||||
if not lexer:
|
if not lexer:
|
||||||
# If no lexer is guessed, default to bash
|
# If no lexer is guessed, default to bash
|
||||||
lexer = get_lexer_by_name('bash')
|
lexer = get_lexer_by_name('bash')
|
||||||
else:
|
else:
|
||||||
# If no language is specified, guess the lexer
|
# If no language is specified, guess the lexer
|
||||||
lexer = guess_lexer('\n'.join(code.split('\n')[1:-1]))
|
try:
|
||||||
if not lexer:
|
lexer = guess_lexer('\n'.join(code.split('\n')[1:-1]))
|
||||||
# If no lexer is guessed, default to bash
|
if not lexer:
|
||||||
|
# If no lexer is guessed, default to bash
|
||||||
|
lexer = get_lexer_by_name('bash')
|
||||||
|
except:
|
||||||
lexer = get_lexer_by_name('bash')
|
lexer = get_lexer_by_name('bash')
|
||||||
|
|
||||||
formatter = TerminalFormatter()
|
formatter = TerminalFormatter()
|
||||||
|
|
||||||
just_code = code.split('\n')[0]
|
|
||||||
newlines = '\n'.join(code.split('\n')[1:])
|
newlines = '\n'.join(code.split('\n')[1:])
|
||||||
# if code is a code block, strip surrounding block markers
|
# if code is a code block, strip surrounding block markers
|
||||||
lines = code.split('\n')
|
lines = code.split('\n')
|
||||||
if (len(lines) > 2) and ('```' in lines[0]) and ('```' in lines[-1]):
|
if (len(lines) > 2) and ('```' in lines[0]) and ('```' in lines[-1]):
|
||||||
just_code = '\n'.join(code.split('\n')[1:-1])
|
just_code = '\n'.join(code.split('\n')[1:-1])
|
||||||
|
else:
|
||||||
|
just_code = code.split('\n')[0] # Inline code
|
||||||
|
|
||||||
highlighted_code = pygments.highlight(just_code, lexer, formatter)
|
highlighted_code = pygments.highlight(just_code, lexer, formatter)
|
||||||
return highlighted_code + newlines
|
return highlighted_code + newlines
|
||||||
|
|
@ -338,10 +343,12 @@ def improved_input():
|
||||||
Returns the complete input text when the user indicates they're done.
|
Returns the complete input text when the user indicates they're done.
|
||||||
"""
|
"""
|
||||||
lines = []
|
lines = []
|
||||||
|
print("> ", end='', flush=True)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
line = sys.stdin.readline()
|
line = sys.stdin.readline()
|
||||||
if not line:
|
if not line:
|
||||||
|
print("", flush=True) # Some visual feedback that user input has ended
|
||||||
break # EOF
|
break # EOF
|
||||||
if line.strip() == '':
|
if line.strip() == '':
|
||||||
continue # ignore empty lines
|
continue # ignore empty lines
|
||||||
|
|
@ -389,8 +396,8 @@ def handle_non_piped_input(args):
|
||||||
except (EOFError, KeyboardInterrupt):
|
except (EOFError, KeyboardInterrupt):
|
||||||
print("\nExiting...")
|
print("\nExiting...")
|
||||||
break
|
break
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
print(f"Error: {e}")
|
# print(f"Error: {e}")
|
||||||
|
|
||||||
client = None
|
client = None
|
||||||
|
|
||||||
|
|
@ -409,7 +416,7 @@ def main():
|
||||||
temp = float(args.temp)
|
temp = float(args.temp)
|
||||||
if args.context:
|
if args.context:
|
||||||
global num_ctx
|
global num_ctx
|
||||||
num_ctx = float(args.context)
|
num_ctx = int(args.context)
|
||||||
if args.new:
|
if args.new:
|
||||||
global history
|
global history
|
||||||
history = [system_prompt]
|
history = [system_prompt]
|
||||||
|
|
|
||||||
260
refactored.py
Normal file
260
refactored.py
Normal file
|
|
@ -0,0 +1,260 @@
|
||||||
|
from ollama import Client
|
||||||
|
import re
|
||||||
|
import pyperclip
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
import pygments
|
||||||
|
from pygments.lexers import get_lexer_by_name, guess_lexer
|
||||||
|
from pygments.formatters import TerminalFormatter
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
class AIAssistant:
|
||||||
|
def __init__(self, server="http://localhost:11434", model="gemma3:12b"):
|
||||||
|
self.server = server
|
||||||
|
self.model = model
|
||||||
|
self.client = Client(host=self.server)
|
||||||
|
self.temperature = 0.2
|
||||||
|
self.num_ctx = 4096
|
||||||
|
self.history = [self.system_prompt()]
|
||||||
|
self.load_history()
|
||||||
|
|
||||||
|
def set_host(self, host):
|
||||||
|
self.server = host
|
||||||
|
self.client = Client(host=host)
|
||||||
|
|
||||||
|
def system_prompt(self):
|
||||||
|
return {"role": "system", "content": "You are a helpful, smart, kind, and efficient AI assistant. You always fulfill the user's requests accurately and concisely."}
|
||||||
|
|
||||||
|
def load_history(self):
|
||||||
|
path = os.environ.get('HOME') + '/.cache/ai-assistant.history'
|
||||||
|
try:
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
self.history = json.load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save_history(self):
|
||||||
|
path = os.environ.get('HOME') + '/.cache/ai-assistant.history'
|
||||||
|
with open(path, 'w+') as f:
|
||||||
|
json.dump(self.history, f)
|
||||||
|
|
||||||
|
def chat(self, message, stream=True):
|
||||||
|
self.history.append({"role": "user", "content": message})
|
||||||
|
completion = self.client.chat(
|
||||||
|
model=self.model,
|
||||||
|
options={"temperature": self.temperature, "num_ctx": self.num_ctx},
|
||||||
|
messages=self.history,
|
||||||
|
stream=stream
|
||||||
|
)
|
||||||
|
result = ''
|
||||||
|
large_chunk = []
|
||||||
|
for chunk in completion:
|
||||||
|
text = chunk['message']['content']
|
||||||
|
large_chunk.append(text)
|
||||||
|
if stream:
|
||||||
|
print(text, end='', flush=True)
|
||||||
|
if not stream:
|
||||||
|
result = completion['message']['content']
|
||||||
|
else:
|
||||||
|
result = ''.join(large_chunk)
|
||||||
|
self.history.append({"role": 'assistant', 'content': result})
|
||||||
|
self.save_history()
|
||||||
|
return result
|
||||||
|
|
||||||
|
class CommandLineParser:
|
||||||
|
def __init__(self):
|
||||||
|
self.parser = argparse.ArgumentParser(description='Chat with an intelligent assistant')
|
||||||
|
self.add_arguments()
|
||||||
|
|
||||||
|
def add_arguments(self):
|
||||||
|
parser = self.parser
|
||||||
|
parser.add_argument('--host', nargs='?', const=True, default=False, help='Specify host of Ollama server')
|
||||||
|
parser.add_argument('--model', '-m', nargs='?', const=True, default=False, help='Specify model')
|
||||||
|
parser.add_argument('--temp', '-t', nargs='?', type=float, const=0.2, default=False, help='Specify temperature')
|
||||||
|
parser.add_argument('--context', type=int, default=4096, help='Specify context size')
|
||||||
|
parser.add_argument('--reasoning', '-r', action='store_true', help='Use the default reasoning model deepseek-r1:14b')
|
||||||
|
parser.add_argument('--new', '-n', action='store_true', help='Start a chat with a fresh history')
|
||||||
|
parser.add_argument('--follow-up', '-f', nargs='?', const=True, default=False, help='Ask a follow up question when piping in context')
|
||||||
|
parser.add_argument('--copy', '-c', action='store_true', help='Copy a codeblock if it appears')
|
||||||
|
parser.add_argument('--shell', '-s', nargs='?', const=True, default=False, help='Output a shell command that does as described')
|
||||||
|
|
||||||
|
def parse(self):
|
||||||
|
return self.parser.parse_args()
|
||||||
|
|
||||||
|
class InputHandler:
|
||||||
|
def __init__(self, assistant):
|
||||||
|
self.assistant = assistant
|
||||||
|
|
||||||
|
def handle_input(self, args):
|
||||||
|
if not sys.stdin.isatty():
|
||||||
|
self.handle_piped_input(args)
|
||||||
|
else:
|
||||||
|
self.handle_interactive_input(args)
|
||||||
|
|
||||||
|
def handle_piped_input(self, args):
|
||||||
|
all_input = sys.stdin.read()
|
||||||
|
query = f'Use the following context to answer the question. There will be no follow up questions from the user so make sure your answer is complete:\n{all_input}\n'
|
||||||
|
if args.copy:
|
||||||
|
query += 'Answer the question using a codeblock for any code or shell scripts\n'
|
||||||
|
if args.follow_up:
|
||||||
|
second_input = input('> ')
|
||||||
|
query += f'\n{second_input}'
|
||||||
|
result = self.assistant.chat(query, stream=False)
|
||||||
|
blocks = self.extract_code_block(result)
|
||||||
|
if args.copy and len(blocks):
|
||||||
|
pyperclip.copy(blocks[0])
|
||||||
|
|
||||||
|
def handle_interactive_input(self, args):
|
||||||
|
print("\033[91massistant\033[0m: Type your message (press Ctrl+D to send):")
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
full_input = self.improved_input()
|
||||||
|
if full_input is None:
|
||||||
|
break
|
||||||
|
if full_input.strip() == '':
|
||||||
|
continue
|
||||||
|
print("\033[91massistant\033[0m: ", end='')
|
||||||
|
result = self.assistant.chat(full_input)
|
||||||
|
print()
|
||||||
|
except (EOFError, KeyboardInterrupt):
|
||||||
|
print("\nExiting...")
|
||||||
|
break
|
||||||
|
|
||||||
|
def improved_input(self):
|
||||||
|
lines = []
|
||||||
|
print("> ", end='', flush=True)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
line = sys.stdin.readline()
|
||||||
|
if not line:
|
||||||
|
print(flush=True)
|
||||||
|
break
|
||||||
|
if line.strip() == '':
|
||||||
|
continue
|
||||||
|
lines.append(line)
|
||||||
|
if len(lines) > 1 and lines[-2] == '\n' and lines[-1] == '\n':
|
||||||
|
break
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nUser aborted input")
|
||||||
|
return None
|
||||||
|
full_input = ''.join(lines).rstrip('\n')
|
||||||
|
return full_input
|
||||||
|
|
||||||
|
def extract_code_block(self, text):
|
||||||
|
pattern = r'```[a-z]*\n[\s\S]*?\n```'
|
||||||
|
code_blocks = []
|
||||||
|
matches = re.finditer(pattern, text)
|
||||||
|
for match in matches:
|
||||||
|
code_block = match.group(0)
|
||||||
|
lexer_name = self.determine_lexer(code_block)
|
||||||
|
highlighted_code = self.highlight_code(lexer_name, code_block)
|
||||||
|
code_blocks.append(highlighted_code)
|
||||||
|
if not code_blocks:
|
||||||
|
line_pattern = r'`[a-z]*[\s\S]*?`'
|
||||||
|
line_matches = re.finditer(line_pattern, text)
|
||||||
|
for match in line_matches:
|
||||||
|
code_block = match.group(0)
|
||||||
|
code_blocks.append(code_block[1:-1])
|
||||||
|
return code_blocks
|
||||||
|
|
||||||
|
def determine_lexer(self, code_block):
|
||||||
|
lexer_name = None
|
||||||
|
lines = code_block.split('\n')
|
||||||
|
for line in lines:
|
||||||
|
if line.strip().startswith('```'):
|
||||||
|
lexer_part = line.strip().split('```')[1].strip()
|
||||||
|
if lexer_part:
|
||||||
|
lexer_name = lexer_part
|
||||||
|
break
|
||||||
|
elif line.strip().startswith('lang:'):
|
||||||
|
lexer_part = line.strip().split(':')[1].strip()
|
||||||
|
if lexer_part:
|
||||||
|
lexer_name = lexer_part
|
||||||
|
break
|
||||||
|
return lexer_name
|
||||||
|
|
||||||
|
def highlight_code(self, lexer_name, code):
|
||||||
|
try:
|
||||||
|
lexer = get_lexer_by_name(lexer_name) if lexer_name else guess_lexer(code)
|
||||||
|
except ValueError:
|
||||||
|
lexer = guess_lexer('\n'.join(code.split('\n')[1:-1]))
|
||||||
|
if not lexer:
|
||||||
|
lexer = get_lexer_by_name('bash')
|
||||||
|
formatter = TerminalFormatter()
|
||||||
|
highlighted_code = pygments.highlight(code, lexer, formatter)
|
||||||
|
return highlighted_code
|
||||||
|
|
||||||
|
class CommandParser:
|
||||||
|
def __init__(self):
|
||||||
|
self.commands = {
|
||||||
|
'/save': self.handle_save,
|
||||||
|
'/clear': self.handle_clear,
|
||||||
|
'/clipboard': self.handle_clipboard,
|
||||||
|
'/exit': self.handle_exit
|
||||||
|
}
|
||||||
|
|
||||||
|
def parse_commands(self, text):
|
||||||
|
tokens = text.split(' ')
|
||||||
|
if not tokens:
|
||||||
|
return False
|
||||||
|
command = tokens[0]
|
||||||
|
if command in self.commands:
|
||||||
|
handler = self.commands[command]
|
||||||
|
if len(tokens) > 1 and command == '/clipboard':
|
||||||
|
context_query = '\n\nThe following is context provided by the user:\n'
|
||||||
|
clipboard_content = pyperclip.paste()
|
||||||
|
if clipboard_content:
|
||||||
|
context_query += clipboard_content + '\n'
|
||||||
|
return handler(context_query)
|
||||||
|
else:
|
||||||
|
handler()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def handle_save(self):
|
||||||
|
filename = input('Enter filename to save conversation: ')
|
||||||
|
self.save_conversation(filename)
|
||||||
|
|
||||||
|
def save_conversation(self, filename='conversation.md'):
|
||||||
|
if not filename.endswith('.md'):
|
||||||
|
filename += '.md'
|
||||||
|
base, extension = os.path.splitext(filename)
|
||||||
|
i = 1
|
||||||
|
while os.path.exists(filename):
|
||||||
|
filename = f"{base}_{i}{extension}"
|
||||||
|
i += 1
|
||||||
|
with open(filename, 'w') as f:
|
||||||
|
f.write(conversation)
|
||||||
|
|
||||||
|
def handle_clear(self):
|
||||||
|
self.assistant.history = [self.assistant.system_prompt()]
|
||||||
|
self.assistant.save_history()
|
||||||
|
|
||||||
|
def handle_clipboard(self, context_query):
|
||||||
|
# Implementation for clipboard command
|
||||||
|
pass
|
||||||
|
|
||||||
|
def handle_exit(self):
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = CommandLineParser().parse()
|
||||||
|
assistant = AIAssistant()
|
||||||
|
if args.host:
|
||||||
|
assistant.set_host(args.host)
|
||||||
|
if args.model:
|
||||||
|
assistant.model = args.model
|
||||||
|
if args.temp:
|
||||||
|
assistant.temperature = args.temp
|
||||||
|
if args.context:
|
||||||
|
assistant.num_ctx = args.context
|
||||||
|
if args.new:
|
||||||
|
assistant.history = [assistant.system_prompt()]
|
||||||
|
assistant.save_history()
|
||||||
|
|
||||||
|
input_handler = InputHandler(assistant)
|
||||||
|
input_handler.handle_input(args)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
Loading…
Reference in a new issue