-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathlanguage_learning.lua
154 lines (130 loc) · 3.96 KB
/
language_learning.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
local make_map = require 'common.make_map'
local pickups = require 'common.pickups'
local screen_message = require 'common.screen_message'
local tensor = require 'dmlab.system.tensor'
local random = require 'common.random'
local custom_observations = require 'decorators.custom_observations'
local timeout = require 'decorators.timeout'
-- Level API table. The callbacks defined further down are attached to it and
-- it is returned at the end of this file.
local api = {}
-- Text layout passed to make_map.makeMap as the entity layer (see api:nextMap).
-- Every ' ' character in this string is a candidate cell: the placement loop
-- that follows may overwrite it with one of the letters in object_types.
-- NOTE(review): '*' presumably marks walls and 'P' the player spawn point --
-- confirm against make_map.
-- NOTE(review): the rows below do not all have the same width (14 characters
-- for the border rows, fewer for the interior rows). If interior padding was
-- lost, restore it so that every row is as wide as the border -- confirm
-- against the intended room shape.
local entityLayer = [[
**************
* *
* *
* *
***** *****
***** *****
***** *****
* *
* *
* *
* P *
**************
]]
-- Text layout passed to make_map.makeMap as the variations layer (see
-- api:nextMap). It contains only '.' characters.
-- NOTE(review): this layer has 10 rows while entityLayer has 12 -- confirm
-- whether make_map expects both layers to have the same dimensions.
local variationsLayer = [[
..............
..............
..............
..............
..............
..............
..............
..............
..............
..............
]]
-- Seed common.random with a fixed value before the placement loop draws from
-- it (presumably so the generated level is the same on every load -- confirm
-- against common.random).
random.seed(32)
------------------------------------------------------------------
-- Entity-layer letters that can be dropped onto free cells of the map.
local object_types = { 'A', 'F', 'L', 'S' }

-- Layout identifiers (not referenced anywhere else in this file).
local layout_types = { 'A', 'B', 'C' }

-- Entity-layer letter -> name of the object it stands for.
local object_map = {
  A = 'apple',
  F = 'fungi',
  L = 'lemon',
  S = 'strawberry',
}

-- Vocabulary of the instruction language: word -> integer id.
local word2id = {
  find = 0,
  apple = 1,
  fungi = 2,
  lemon = 3,
  strawberry = 4,
  pad = 5,
  start = 6,
  stop = 7,
}

-- Reverse vocabulary (integer id -> word), obtained by inverting word2id.
local id2word = {}
for word, id in pairs(word2id) do
  id2word[id] = word
end
-----------------------------------------------------------------
-- Build the map with random positioning of objects.
--
-- Each ' ' cell of entityLayer is visited once, left to right. With
-- probability OBJECT_PROBABILITY the cell is replaced by a letter drawn
-- uniformly from object_types; otherwise it stays empty. The random draws
-- happen in the same order as a character-by-character scan would make them
-- (one uniformReal per free cell, followed by one uniformInt when the cell is
-- filled).
--
-- present_objects is used as a set: its keys are the indices into
-- object_types of every type that was placed at least once. Only the keys
-- are meaningful.
local OBJECT_PROBABILITY = 0.4
local present_objects = {}
entityLayer = (entityLayer:gsub(' ', function()
  if random.uniformReal(0, 1) < OBJECT_PROBABILITY then
    -- Use the list length instead of a literal so that adding or removing an
    -- entry in object_types cannot leave this range out of sync.
    local object_to_pick = random.uniformInt(1, #object_types)
    present_objects[object_to_pick] = true
    return object_types[object_to_pick]
  end
  -- Returning nothing makes gsub keep the original ' '.
end))
-- Build the command in text and index encoded.
--
-- command.text    : human readable instruction, e.g. 'find apple'.
-- command.command : tensor holding the word ids {start, find, <object>, stop}.
local command = {}
local keyset = {}
-- Build list of present objects (indices into object_types).
for k in pairs(present_objects) do
  keyset[#keyset + 1] = k
end
-- pairs() visits keys in an unspecified order; sort so that the same random
-- draw always selects the same object.
table.sort(keyset)
-- Fail with a clear message instead of an obscure nil error further down.
assert(#keyset > 0,
       'no objects were placed on the map; cannot build a "find" command')
-- Choose a random object type from the present objects.
local object_to_find = random.uniformInt(1, #keyset)
-- Resolve the word through the lookup tables rather than reusing the
-- object_types index as a word id, which only works while both tables happen
-- to list the objects in the same order.
local target_word = object_map[object_types[keyset[object_to_find]]]
command['command'] = tensor.Tensor{
  word2id['start'],
  word2id['find'],
  word2id[target_word],
  word2id['stop'],
}
command['text'] = 'find ' .. target_word
-- Set the score of the game based on the chosen object.
-- The chosen object type is worth +1; every other object type is worth -1.
local target_letter = object_types[keyset[object_to_find]]
-- Entity-layer letter -> class name of the matching reward pickup.
local reward_by_letter = {
  A = 'apple_reward',
  F = 'fungi_reward',
  S = 'strawberry_reward',
  L = 'lemon_reward',
}
-- Falls back to '' when the letter is not one of the four known ones.
local reward = reward_by_letter[target_letter] or ''
local reward_classes = {
  'apple_reward', 'fungi_reward', 'strawberry_reward', 'lemon_reward',
}
for _, class_name in ipairs(reward_classes) do
  local pickup = pickups.defaults[class_name]
  pickup['quantity'] = (pickup['class_name'] == reward) and 1 or -1
end
-------------------------------------------------------------------
-- Dump the generated level and the rewarded pickup class to stdout.
print('LEVEL SHAPE')
print(entityLayer)
print('Reward object', reward)
--- Episode-start callback.
-- Forwards `seed` to make_map.seedRng and resets the counter that
-- api:nextMap uses to number the maps it generates.
-- @param episode Not used by this level.
-- @param seed Value handed to make_map.seedRng.
function api:start(episode, seed)
  make_map.seedRng(seed)
  api._count = 0
end
-- Custom observations served by api:customObservation, keyed by name.
-- ORDER is the index-encoded command tensor.
local observationTable = {}
observationTable.ORDER = command.command
--- Describes the custom observations this level exposes.
-- @return A list with a single spec: 'ORDER', of type 'Doubles', shape {4}.
function api:customObservationSpec()
  local order_spec = {name = 'ORDER', type = 'Doubles', shape = {4}}
  return {order_spec}
end
--- Looks up a custom observation by name.
-- @param name Observation name, e.g. 'ORDER'.
-- @return The entry stored in observationTable under `name`, or nil when
-- there is none.
function api:customObservation(name)
  local observation = observationTable[name]
  return observation
end
--- Command-line callback; delegates entirely to make_map.commandLine.
-- @param oldCommandLine Command line passed through to make_map.commandLine.
-- @return Whatever make_map.commandLine returns.
function api:commandLine(oldCommandLine)
  return make_map.commandLine(oldCommandLine)
end
--- Returns the pickup definition for a class name.
-- The definitions come from pickups.defaults, whose reward quantities are
-- adjusted earlier in this file.
-- @param className Key into pickups.defaults.
-- @return The matching entry, or nil when the class is unknown.
function api:createPickup(className)
  local pickup = pickups.defaults[className]
  return pickup
end
--- Generates the next map from the entity and variations layers.
-- Increments the map counter and uses it to give each map a distinct name
-- ("demo_map_1", "demo_map_2", ...).
-- @return Whatever make_map.makeMap returns for the generated map.
function api:nextMap()
  local map_index = api._count + 1
  api._count = map_index
  local map_name = "demo_map_" .. map_index
  return make_map.makeMap(map_name, entityLayer, variationsLayer)
end
--- Supplies the on-screen messages for this level.
-- @param args Not used by this level.
-- @return A list with one message: the textual command, left-aligned at
-- position (0, 0).
function api:screenMessages(args)
  return {
    {
      message = command.text,
      x = 0,
      y = 0,
      alignment = screen_message.ALIGN_LEFT,
    },
  }
end
-- Set timeout for 4 minutes: the value handed to the timeout decorator is
-- the minute count multiplied by 60.
local TIMEOUT_MINUTES = 4
timeout.decorate(api, TIMEOUT_MINUTES * 60)

return api