This repository was archived by the owner on May 22, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
38 lines (36 loc) · 1.38 KB
/
utils.py
File metadata and controls
38 lines (36 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def loadLines(fileName, fields):
    """Parse a ' +++$+++ '-delimited file into line records keyed by line ID.

    From PyTorch Tutorial.

    Args:
        fileName: Path of the file to read; it is decoded as ISO-8859-1.
        fields: Ordered field names. The n-th name is bound to the n-th
            delimited column of every row. Must contain ``'lineID'``.

    Returns:
        dict mapping each row's ``'lineID'`` column to a dict of
        ``{field_name: column_text}``. Columns are not stripped, so a column
        at the end of a row keeps that row's trailing newline. If two rows
        share a line ID, the later row wins.

    Raises:
        IndexError: if a row has fewer columns than there are ``fields``.
        KeyError: if ``'lineID'`` is not one of ``fields``.
    """
    records = {}
    with open(fileName, 'r', encoding='iso-8859-1') as handle:
        for raw in handle:
            columns = raw.split(" +++$+++ ")
            # Positional indexing (not zip) so short rows still raise IndexError.
            record = {name: columns[idx] for idx, name in enumerate(fields)}
            records[record['lineID']] = record
    return records
def loadConversations(fileName, lines, fields):
    """Parse a conversations file and attach the referenced line records.

    From PyTorch Tutorial, with ``eval`` replaced by ``ast.literal_eval`` so
    the file's contents can no longer execute arbitrary code.

    Args:
        fileName: Path of the ' +++$+++ '-delimited file; decoded as ISO-8859-1.
        lines: Mapping from line ID to a line record (e.g. the dict returned
            by ``loadLines``).
        fields: Ordered field names. The n-th name is bound to the n-th
            delimited column of every row. Must contain ``"utteranceIDs"``,
            whose column holds a Python list literal such as ``['L1', 'L2']``.

    Returns:
        list with one dict per row: ``{field_name: column_text, ...}`` plus a
        ``"lines"`` key holding ``lines[id]`` for every ID in ``utteranceIDs``,
        in the listed order. Column text is not stripped.

    Raises:
        IndexError: if a row has fewer columns than there are ``fields``.
        KeyError: if ``"utteranceIDs"`` is not among ``fields`` or an ID is
            missing from ``lines``.
        ValueError / SyntaxError: if the ``utteranceIDs`` column is not a
            Python literal.
    """
    import ast  # local import keeps this block self-contained

    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            convObj = {field: values[i] for i, field in enumerate(fields)}
            # literal_eval only accepts literals, so a crafted data file cannot
            # run code here (plain eval could). strip() drops the row's trailing
            # newline and any leading space that older literal_eval rejects.
            lineIds = ast.literal_eval(convObj["utteranceIDs"].strip())
            convObj["lines"] = [lines[lineId] for lineId in lineIds]
            conversations.append(convObj)
    return conversations
def extractSentencePairs(conversations):
    """Build [input, target] sentence pairs from consecutive utterances.

    From PyTorch Tutorial.

    Args:
        conversations: Iterable of dicts, each with a ``"lines"`` sequence
            whose elements are dicts carrying a ``"text"`` string (e.g. the
            output of ``loadConversations``).

    Returns:
        list of ``[inputLine, targetLine]`` lists, one for each adjacent pair
        of utterances in a conversation, with both texts whitespace-stripped.
        A pair is skipped when either side is empty after stripping.
        Conversations with fewer than two lines contribute nothing.
    """
    qa_pairs = []
    for conversation in conversations:
        utterances = conversation["lines"]
        # Pair every utterance with the one that directly follows it.
        for current, following in zip(utterances, utterances[1:]):
            inputLine = current["text"].strip()
            targetLine = following["text"].strip()
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs