Skip to content

Commit 75a9edd

Browse files
committed
[WIP] predict next words
1 parent db5db28 commit 75a9edd

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

src/InputController.h

+4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#import "ConversionEngine.h"
66

77
@interface InputController : IMKInputController {
8+
NSMutableString *_sentenceBuffer;
89
NSMutableString *_composedBuffer;
910
NSMutableString *_originalBuffer;
1011
NSInteger _insertionIndex;
@@ -17,6 +18,9 @@
1718
AnnotationWinController *_annotationWin;
1819
}
1920

21+
- (NSMutableString *)sentenceBuffer;
22+
- (void)setSentenceBuffer:(NSString *)string;
23+
2024
- (NSMutableString *)composedBuffer;
2125
- (void)setComposedBuffer:(NSString *)string;
2226
- (NSMutableString *)originalBuffer;

src/InputController.mm

+98
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ - (BOOL)onKeyEvent:(NSEvent *)event client:(id)sender {
173173
if (hasBufferedText) {
174174
[self appendToComposedBuffer:characters];
175175
[self commitCompositionWithoutSpace:sender];
176+
[self setSentenceBuffer: @""];
176177
return YES;
177178
}
178179
}
@@ -228,6 +229,27 @@ - (void)commitComposition:(id)sender {
228229
[sender insertText:text replacementRange:NSMakeRange(NSNotFound, NSNotFound)];
229230

230231
[self reset];
232+
233+
NSLog(@"Current Sentence Buffer: %@", self.sentenceBuffer);
234+
if ([self doesSentenceBufferIncludeSpace]) {
235+
[self fetchPredictionsForText:self.sentenceBuffer completion:^(NSDictionary *responseDict, NSArray *bertArray, NSError *error) {
236+
if (error) {
237+
NSLog(@"Error: %@", error.localizedDescription);
238+
} else {
239+
NSLog(@"BERT: %@", bertArray);
240+
dispatch_async(dispatch_get_main_queue(), ^{
241+
[sharedCandidates setCandidateData:bertArray];
242+
[sharedCandidates show:kIMKLocateCandidatesBelowHint];
243+
});
244+
}
245+
}];
246+
}
247+
248+
}
249+
250+
- (BOOL)doesSentenceBufferIncludeSpace {
251+
NSRange range = [self.sentenceBuffer rangeOfString:@" "];
252+
return range.location != NSNotFound;
231253
}
232254

233255
- (void)commitCompositionWithoutSpace:(id)sender {
@@ -242,6 +264,66 @@ - (void)commitCompositionWithoutSpace:(id)sender {
242264
[self reset];
243265
}
244266

267+
- (NSString *) fetchAPIURL {
268+
NSUserDefaults *defaults = [NSUserDefaults standardUserDefaults];
269+
NSString *apiURL = [defaults stringForKey:@"NEXT_WORD_PREDICTION_SERVICE_URL"];
270+
if (apiURL) {
271+
return apiURL;
272+
} else {
273+
return @"http://127.0.0.1:8080/get_end_predictions";
274+
}
275+
}
276+
277+
- (void)fetchPredictionsForText:(NSString *)text completion:(void(^)(NSDictionary *responseDict, NSArray *bertArray, NSError *error))completionHandler {
278+
NSString *urlString = [self fetchAPIURL];
279+
NSURL *url = [NSURL URLWithString:urlString];
280+
NSMutableURLRequest *request = [NSMutableURLRequest requestWithURL:url];
281+
request.HTTPMethod = @"POST";
282+
[request setValue:@"application/json" forHTTPHeaderField:@"Content-Type"];
283+
284+
NSDictionary *jsonBody = @{@"input_text": text, @"top_k": @"9"};
285+
NSError *jsonError;
286+
NSData *jsonData = [NSJSONSerialization dataWithJSONObject:jsonBody options:0 error:&jsonError];
287+
288+
if (jsonError) {
289+
completionHandler(nil, nil, jsonError);
290+
return;
291+
}
292+
293+
request.HTTPBody = jsonData;
294+
295+
NSURLSession *session = [NSURLSession sharedSession];
296+
NSURLSessionDataTask *task = [session dataTaskWithRequest:request completionHandler:^(NSData *data, NSURLResponse *response, NSError *error) {
297+
if (error) {
298+
completionHandler(nil, nil, error);
299+
return;
300+
}
301+
302+
NSError *jsonParsingError;
303+
NSDictionary *responseDict = [NSJSONSerialization JSONObjectWithData:data options:0 error:&jsonParsingError];
304+
305+
if (jsonParsingError) {
306+
completionHandler(nil, nil, jsonParsingError);
307+
} else {
308+
NSArray *bertArray = nil;
309+
NSArray *bertCNArray = nil;
310+
311+
// Parsing the bert string
312+
NSString *bertString = [responseDict objectForKey:@"bert"];
313+
if (bertString) {
314+
bertArray = [bertString componentsSeparatedByString:@"\n"];
315+
bertArray = [bertArray filteredArrayUsingPredicate:[NSPredicate predicateWithBlock:^BOOL(id evaluatedObject, NSDictionary *bindings) {
316+
return [evaluatedObject length] > 0 && ![evaluatedObject isEqualToString:@"[UNK]"];
317+
}]];
318+
}
319+
320+
completionHandler(responseDict, bertArray, nil);
321+
}
322+
}];
323+
324+
[task resume];
325+
}
326+
245327
- (void)reset {
246328
[self setComposedBuffer:@""];
247329
[self setOriginalBuffer:@""];
@@ -264,6 +346,22 @@ - (NSMutableString *)composedBuffer {
264346

265347
- (void)setComposedBuffer:(NSString *)string {
266348
NSMutableString *buffer = [self composedBuffer];
349+
if (string && string.length > 0) {
350+
NSString * sentence = self.sentenceBuffer;
351+
[self setSentenceBuffer: [NSString stringWithFormat:@"%@ %@", sentence, string]];
352+
}
353+
[buffer setString:string];
354+
}
355+
356+
- (NSMutableString *)sentenceBuffer {
357+
if (_sentenceBuffer == nil) {
358+
_sentenceBuffer = [[NSMutableString alloc] init];
359+
}
360+
return _sentenceBuffer;
361+
}
362+
363+
- (void)setSentenceBuffer:(NSString *)string {
364+
NSMutableString *buffer = [self sentenceBuffer];
267365
[buffer setString:string];
268366
}
269367

0 commit comments

Comments
 (0)