Compare commits

...

3 commits

2 changed files with 121 additions and 85 deletions

View file

@ -4,6 +4,9 @@ from ollama import Client
import re
import pyperclip
import sys
import tty
import termios
import signal
import argparse
import pygments
from pygments.lexers import get_lexer_by_name, guess_lexer
@ -11,6 +14,9 @@ from pygments.formatters import TerminalFormatter
import os
import json
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit import PromptSession
server = 'localhost:11434'
model = 'gemma3:12b'
reasoning_model='deepseek-r1:14b'
@ -344,37 +350,24 @@ def handle_piped_input(args):
if args.copy and len(blocks):
copy_string_to_clipboard(blocks[0])
def improved_input():
kb = KeyBindings()
@kb.add('c-d')
def _(event):
event.current_buffer.validate_and_handle()
session = PromptSession(multiline=True, prompt_continuation='', key_bindings=kb)
def improved_input(prompt="> "):
"""
Handles multi-line input and prevents accidental sending of pasted content.
Returns the complete input text when the user indicates they're done.
Returns the full text (including embedded newlines) when you press Ctrl-D.
Arrow keys edit within or across lines automatically.
"""
lines = []
print("> ", end='', flush=True)
while True:
try:
line = sys.stdin.readline()
if not line:
print("", flush=True) # Some visual feedback that user input has ended
break # EOF
if line.strip() == '':
continue # ignore empty lines
lines.append(line)
# Check for termination signal (Ctrl+D or two consecutive empty lines)
if len(lines) > 1 and lines[-2] == '\n' and lines[-1] == '\n':
break
except KeyboardInterrupt:
print("\nUser aborted input")
return None
except Exception as e:
print(f"Error reading input: {e}")
return None
# Join the lines and strip any trailing newlines
full_input = ''.join(lines).rstrip('\n')
return full_input
try:
text = session.prompt(prompt)
return text
except KeyboardInterrupt:
print("\nUser aborted input")
return None
def handle_non_piped_input(args):
if args.shell:

View file

@ -1,13 +1,20 @@
#/bin/python
# AI Assistant in the terminal
import argparse
import os
import sys
import json
from ollama import Client
import re
import pyperclip
import sys
import argparse
import pygments
from pygments.lexers import get_lexer_by_name, guess_lexer
from pygments.formatters import TerminalFormatter
import os
import json
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit import PromptSession
class AIAssistant:
def __init__(self, server="http://localhost:11434", model="gemma3:12b"):
@ -17,7 +24,6 @@ class AIAssistant:
self.temperature = 0.2
self.num_ctx = 4096
self.history = [self.system_prompt()]
self.load_history()
def set_host(self, host):
self.server = host
@ -39,6 +45,33 @@ class AIAssistant:
with open(path, 'w+') as f:
json.dump(self.history, f)
def determine_lexer(self, code_block):
lexer_name = None
lines = code_block.split('\n')
for line in lines:
if line.strip().startswith('```'):
lexer_part = line.strip().split('```')[1].strip()
if lexer_part:
lexer_name = lexer_part
break
elif line.strip().startswith('lang:'):
lexer_part = line.strip().split(':')[1].strip()
if lexer_part:
lexer_name = lexer_part
break
return lexer_name
def highlight_code(self, lexer_name, code):
try:
lexer = get_lexer_by_name(lexer_name) if lexer_name else guess_lexer(code)
except ValueError:
lexer = guess_lexer('\n'.join(code.split('\n')[1:-1]))
if not lexer:
lexer = get_lexer_by_name('bash')
formatter = TerminalFormatter()
highlighted_code = pygments.highlight(code, lexer, formatter)
return highlighted_code
def chat(self, message, stream=True):
self.history.append({"role": "user", "content": message})
completion = self.client.chat(
@ -49,11 +82,28 @@ class AIAssistant:
)
result = ''
large_chunk = []
language = None
for chunk in completion:
text = chunk['message']['content']
large_chunk.append(text)
large_text = ''.join(large_chunk)
if ('```' in large_text) and ('\n' in large_text.split('```')[1]):
language = large_text.split('```')[1].split('\n')[0]
large_chunk = []
if language == '':
print(large_text, end='', flush=True)
language = None
if stream:
print(text, end='', flush=True)
if language and ('\n' in large_text) and large_chunk:
output = self.highlight_code(language, large_text)
print(output, end='', flush=True)
large_chunk = []
elif not language or not large_chunk:
print(text, end='', flush=True)
if not stream:
result = completion['message']['content']
else:
@ -82,9 +132,17 @@ class CommandLineParser:
def parse(self):
return self.parser.parse_args()
# Keybindings
bindings = KeyBindings()
@bindings.add('c-d')
def _(event):
event.current_buffer.validate_and_handle()
class InputHandler:
def __init__(self, assistant):
self.assistant = assistant
self.session = PromptSession(multiline=True, prompt_continuation='', key_bindings=bindings)
def handle_input(self, args):
if not sys.stdin.isatty():
@ -98,14 +156,33 @@ class InputHandler:
if args.copy:
query += 'Answer the question using a codeblock for any code or shell scripts\n'
if args.follow_up:
second_input = input('> ')
second_input = improved_input()
query += f'\n{second_input}'
result = self.assistant.chat(query, stream=False)
blocks = self.extract_code_block(result)
if args.copy and len(blocks):
pyperclip.copy(blocks[0])
def arg_shell(args):
query = '''
Form a shell command based on the following description. Only output a working shell command. Format the command like this: `command`
Description:
'''
if args.shell != True:
query += args.shell
else:
query += self.improved_input()
result = self.assistant.chat(query, stream=False)
result = blocks[0] if len(blocks := self.extract_code_block(result)) else result
print(result)
copy_string_to_clipboard(result)
def handle_interactive_input(self, args):
if args.shell:
self.arg_shell(args)
exit()
print("\033[91massistant\033[0m: Type your message (press Ctrl+D to send):")
while True:
try:
@ -114,32 +191,23 @@ class InputHandler:
break
if full_input.strip() == '':
continue
print("\033[91massistant\033[0m: ", end='')
result = self.assistant.chat(full_input)
print()
except (EOFError, KeyboardInterrupt):
print("\nExiting...")
break
def improved_input(self):
lines = []
print("> ", end='', flush=True)
while True:
try:
line = sys.stdin.readline()
if not line:
print(flush=True)
break
if line.strip() == '':
continue
lines.append(line)
if len(lines) > 1 and lines[-2] == '\n' and lines[-1] == '\n':
break
except KeyboardInterrupt:
print("\nUser aborted input")
return None
full_input = ''.join(lines).rstrip('\n')
return full_input
def improved_input(self, prompt="> "):
"""
Returns the full text (including embedded newlines) when you press Ctrl-D.
Arrow keys edit within or across lines automatically.
"""
try:
text = self.session.prompt(prompt)
return text
except KeyboardInterrupt:
print("\nUser aborted input")
return None
def extract_code_block(self, text):
pattern = r'```[a-z]*\n[\s\S]*?\n```'
@ -147,8 +215,8 @@ class InputHandler:
matches = re.finditer(pattern, text)
for match in matches:
code_block = match.group(0)
lexer_name = self.determine_lexer(code_block)
highlighted_code = self.highlight_code(lexer_name, code_block)
lexer_name = self.assistant.determine_lexer(code_block)
highlighted_code = self.assistant.highlight_code(lexer_name, code_block)
code_blocks.append(highlighted_code)
if not code_blocks:
line_pattern = r'`[a-z]*[\s\S]*?`'
@ -158,33 +226,6 @@ class InputHandler:
code_blocks.append(code_block[1:-1])
return code_blocks
def determine_lexer(self, code_block):
lexer_name = None
lines = code_block.split('\n')
for line in lines:
if line.strip().startswith('```'):
lexer_part = line.strip().split('```')[1].strip()
if lexer_part:
lexer_name = lexer_part
break
elif line.strip().startswith('lang:'):
lexer_part = line.strip().split(':')[1].strip()
if lexer_part:
lexer_name = lexer_part
break
return lexer_name
def highlight_code(self, lexer_name, code):
try:
lexer = get_lexer_by_name(lexer_name) if lexer_name else guess_lexer(code)
except ValueError:
lexer = guess_lexer('\n'.join(code.split('\n')[1:-1]))
if not lexer:
lexer = get_lexer_by_name('bash')
formatter = TerminalFormatter()
highlighted_code = pygments.highlight(code, lexer, formatter)
return highlighted_code
class CommandParser:
def __init__(self):
self.commands = {
@ -252,6 +293,8 @@ def main():
if args.new:
assistant.history = [assistant.system_prompt()]
assistant.save_history()
else:
assistant.load_history()
input_handler = InputHandler(assistant)
input_handler.handle_input(args)