Skip to content

Commit

Permalink
make prompt prefix and postfix configurable
Browse files Browse the repository at this point in the history
fix NPE
enable clients to send a prompt directly with an IQ message
  • Loading branch information
deleolajide committed Jan 1, 2024
1 parent 2627ae5 commit 9290f41
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 33 deletions.
6 changes: 3 additions & 3 deletions src/java/org/ifsoft/llama/openfire/LLaMA.java
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public void initializePlugin(final PluginManager manager, final File pluginDirec
if (llamaHosted) {
createLLaMAUser();
loginLLaMAUser(true);
llamaConnection.handlePrediction("what is your name?", null, null);
llamaConnection.handlePrediction("what is your name?", null, null, null);
} else {
setupLLaMA(pluginDirectory);
}
Expand Down Expand Up @@ -181,7 +181,7 @@ public void onOutputLine(final String line) {
if (line.contains("HTTP server listening") && llamaConnection != null) {
Log.info("Sending test data to LLaMA");

llamaConnection.handlePrediction("what is your name?", null, null);
llamaConnection.handlePrediction("what is your name?", null, null, null);
}
}

Expand Down Expand Up @@ -447,7 +447,7 @@ public void messageReceived(JID roomJID, JID user, String nickname, Message mess
Thread.sleep(1000);
}

//llamaConnection.handlePrediction(body, roomJID, message.getType());
//llamaConnection.handlePrediction(body, roomJID, message.getType(), null);
}
} catch (Exception e) {
Log.error("unable to handle groupchat message", e);
Expand Down
85 changes: 55 additions & 30 deletions src/java/org/ifsoft/llama/openfire/LLaMAConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public LLaMAConnection(String username, String remoteUrl) {
}
}

public void handlePrediction(final String prompt, final JID requestor, final Message.Type chatType) {
public void handlePrediction(final String prompt, final JID requestor, final Message.Type chatType, final IQ reply) {
exec.execute(new Runnable() {
public void run() {
long threadId = Thread.currentThread().getId()% LLaMA.numThreads;
Expand All @@ -125,32 +125,44 @@ public void run() {
"assistant_name": "Assistant:"
}
*/
double temperature = 0.5;
double top_p = 0.9;

try {
temperature = Double.parseDouble(JiveGlobals.getProperty("llama.temperature", "0.5"));
top_p = Double.parseDouble(JiveGlobals.getProperty("llama.top.p.sampling", "0.9"));
} catch (Exception e) {
Log.error("Unable to set temperature or top_p", e);
}
JSONObject systemPrompt = new JSONObject();
systemPrompt.put("prompt", JiveGlobals.getProperty("llama.system.prompt", LLaMA.getSystemPrompt()));
systemPrompt.put("anti_prompt", "User:");
systemPrompt.put("assistant_name", alias);

JSONObject testData = new JSONObject();
testData.put("system_prompt", systemPrompt);
testData.put("prompt", "[INST]" + prompt + "[/INST]");
testData.put("n_predict", JiveGlobals.getIntProperty("llama.predictions", 256));
testData.put("stream", true);
testData.put("cache_prompt", JiveGlobals.getBooleanProperty("llama.cache.prompt", true));
testData.put("slot_id", threadId);
testData.put("temperature", temperature);
testData.put("top_k", JiveGlobals.getIntProperty("llama.top.k.sampling", 40));
testData.put("top_p", top_p);
JSONObject testData = new JSONObject();
double temperature = Double.parseDouble(JiveGlobals.getProperty("llama.temperature", "0.5"));
double top_p = Double.parseDouble(JiveGlobals.getProperty("llama.top.p.sampling", "0.9"));

JSONObject systemPrompt = new JSONObject();
systemPrompt.put("prompt", JiveGlobals.getProperty("llama.system.prompt", LLaMA.getSystemPrompt()));
systemPrompt.put("anti_prompt", "User:");
systemPrompt.put("assistant_name", alias);

testData.put("system_prompt", systemPrompt);
testData.put("prompt", JiveGlobals.getProperty("llama.prefix.prompt", "[INST]") + prompt + JiveGlobals.getProperty("llama.postfix.prompt", "[/INST]"));
testData.put("n_predict", JiveGlobals.getIntProperty("llama.predictions", 256));
testData.put("stream", true);
testData.put("cache_prompt", JiveGlobals.getBooleanProperty("llama.cache.prompt", true));
testData.put("slot_id", threadId);
testData.put("temperature", temperature);
testData.put("top_k", JiveGlobals.getIntProperty("llama.top.k.sampling", 40));
testData.put("top_p", top_p);


if (reply != null) {
final String response = getJson("/completion", testData, null, null);
reply.setChildElement("response", "urn:xmpp:gen-ai:0").setText(response);
XMPPServer.getInstance().getRoutingTable().routePacket(reply.getTo(), reply, true);
}
else {
getJson("/completion", testData, requestor, chatType);
}

getJson("/completion", testData, requestor, chatType);
} catch (Exception e) {
Log.error("Unable to run infer prompt", e);

if (reply != null) {
reply.setError(new PacketError(PacketError.Condition.internal_server_error, PacketError.Type.modify, e.toString()));
}
}
}
});
}
Expand Down Expand Up @@ -268,11 +280,11 @@ public void deliver(Packet packet) throws UnauthorizedException {

if (msg.toLowerCase().startsWith(llamaUser.toLowerCase())) {
requestor = new JID(packet.getFrom().toBareJID());
handlePrediction(msg, requestor, message.getType());
handlePrediction(msg, requestor, message.getType(), null);
}

} else {
handlePrediction(msg, requestor, message.getType());
handlePrediction(msg, requestor, message.getType(), null);
}
}
}
Expand All @@ -283,8 +295,17 @@ public void deliver(Packet packet) throws UnauthorizedException {
IQ iq = (IQ) packet;
Log.debug("Incoming IQ " + packet.getFrom() + " " + iq.getType());

IQ reply = IQ.createResultIQ(iq);
XMPPServer.getInstance().getRoutingTable().routePacket(packet.getFrom(), reply, true);
if (iq.getType() != IQ.Type.result) {
IQ reply = IQ.createResultIQ(iq);
final Element element = iq.getChildElement();

if (element != null && element.getNamespaceURI().equals("urn:xmpp:gen-ai:0")) {
handlePrediction(element.getText(), null, null, reply);
return;
}

XMPPServer.getInstance().getRoutingTable().routePacket(packet.getFrom(), reply, true);
}
}
}

Expand Down Expand Up @@ -332,11 +353,12 @@ public boolean isCompressed() {
//
//-------------------------------------------------------

private void getJson(String urlToRead, JSONObject data, JID requestor, Message.Type chatType) {
private String getJson(String urlToRead, JSONObject data, JID requestor, Message.Type chatType) {
URL url;
HttpURLConnection conn;
BufferedReader rd;
String line;
String accumulator = "";
StringBuilder result = new StringBuilder();

String llamaHost = JiveGlobals.getProperty("llama.host", hostname);
Expand Down Expand Up @@ -385,6 +407,7 @@ private void getJson(String urlToRead, JSONObject data, JID requestor, Message.T

if (requestor != null && !isNull(msg)) {
replyChat(msg, requestor, chatType);
accumulator = accumulator + msg;
}

} else {
Expand All @@ -401,7 +424,8 @@ private void getJson(String urlToRead, JSONObject data, JID requestor, Message.T
Log.info("getJson - chat\n" + msg);

if (requestor != null) {
replyChat(msg, requestor, chatType);
replyChat(msg, requestor, chatType);
accumulator = accumulator + msg;
}
}
} else {
Expand All @@ -414,6 +438,7 @@ private void getJson(String urlToRead, JSONObject data, JID requestor, Message.T
} catch (Exception e) {
Log.error("getJson", e);
}
return accumulator;
}

private void replyState(String state, JID requestor, Message.Type chatType) {
Expand Down

0 comments on commit 9290f41

Please sign in to comment.