File size: 1,578 Bytes
3860419
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
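Tests for the collect_learnings function in the cli/collect module.

NOTE: the test body below is currently commented out, so this module
collects no tests until it is re-enabled.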
"""

import pytest

# def test_collect_learnings(monkeypatch):
#     monkeypatch.setattr(rudder_analytics, "track", MagicMock())
#
#     model = "test_model"
#     temperature = 0.5
#     steps = [simple_gen]
#     dbs = FileRepositories(
#         OnDiskRepository("/tmp"),
#         OnDiskRepository("/tmp"),
#         OnDiskRepository("/tmp"),
#         OnDiskRepository("/tmp"),
#         OnDiskRepository("/tmp"),
#         OnDiskRepository("/tmp"),
#         OnDiskRepository("/tmp"),
#     )
#     dbs.input = {
#         "prompt": "test prompt\n with newlines",
#         "feedback": "test feedback",
#     }
#     code = "this is output\n\nit contains code"
#     dbs.logs = {steps[0].__name__: json.dumps([{"role": "system", "content": code}])}
#     dbs.memory = {"all_output.txt": "test workspace\n" + code}
#
#     collect_learnings(model, temperature, steps, dbs)
#
#     learnings = extract_learning(
#         model, temperature, steps, dbs, steps_file_hash=steps_file_hash()
#     )
#     assert rudder_analytics.track.call_count == 1
#     assert rudder_analytics.track.call_args[1]["event"] == "learning"
#     a = {
#         k: v
#         for k, v in rudder_analytics.track.call_args[1]["properties"].items()
#         if k != "timestamp"
#     }
#     b = {k: v for k, v in learnings.to_dict().items() if k != "timestamp"}
#     assert a == b
#
#     assert json.dumps(code) in learnings.logs
#     assert code in learnings.workspace


if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    # pytest.main() returns an exit code instead of exiting, so it is handed
    # to SystemExit; otherwise the process would exit 0 even when tests fail.
    # Passing __file__ restricts collection to this module instead of
    # whatever directory the script happens to be launched from.
    raise SystemExit(pytest.main(["-v", __file__]))